Skip to content

Commit 8cd811d

Browse files
Merge pull request #192 from daisybio/lto
Leave-tissue-out cross-validation and NaiveTissueMeanPredictor
2 parents 7e4bba4 + 560e90e commit 8cd811d

File tree

19 files changed

+855
-101
lines changed

19 files changed

+855
-101
lines changed

.github/workflows/run_tests.yml

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,7 @@ jobs:
1919
- { python-version: "3.12", os: ubuntu-latest, session: "pre-commit" }
2020
- { python-version: "3.12", os: ubuntu-latest, session: "mypy" }
2121
- { python-version: "3.12", os: ubuntu-latest, session: "tests" }
22-
- { python-version: "3.12", os: windows-latest, session: "tests" }
23-
- { python-version: "3.12", os: ubuntu-latest, session: "typeguard" }
22+
- { python-version: "3.12", os: windows-latest, session: "typeguard" }
2423
- { python-version: "3.12", os: ubuntu-latest, session: "xdoctest" }
2524
- { python-version: "3.12", os: ubuntu-latest, session: "docs-build" }
2625

@@ -98,10 +97,10 @@ jobs:
9897
- name: Check out the repository
9998
uses: actions/checkout@v4
10099

101-
- name: Set up Python 3.11
100+
- name: Set up Python 3.12
102101
uses: actions/setup-python@v5
103102
with:
104-
python-version: 3.11
103+
python-version: 3.12
105104

106105
- name: Install Poetry
107106
run: |

drevalpy/datasets/dataset.py

Lines changed: 98 additions & 42 deletions
Large diffs are not rendered by default.

drevalpy/datasets/loader.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from ..pipeline_function import pipeline_function
1010
from .curvecurator import fit_curves
1111
from .dataset import DrugResponseDataset
12-
from .utils import ALLOWED_MEASURES, CELL_LINE_IDENTIFIER, DRUG_IDENTIFIER, download_dataset
12+
from .utils import ALLOWED_MEASURES, CELL_LINE_IDENTIFIER, DRUG_IDENTIFIER, TISSUE_IDENTIFIER, download_dataset
1313

1414

1515
def check_measure(measure_queried: str, measures_data: list[str], dataset_name: str) -> None:
@@ -56,6 +56,7 @@ def load_gdsc1(
5656
response=response_data[measure].values,
5757
cell_line_ids=response_data[CELL_LINE_IDENTIFIER].values,
5858
drug_ids=response_data[DRUG_IDENTIFIER].values,
59+
tissues=response_data[TISSUE_IDENTIFIER].values,
5960
dataset_name=dataset_name,
6061
)
6162

@@ -97,6 +98,7 @@ def load_ccle(
9798
response=response_data[measure].values,
9899
cell_line_ids=response_data[CELL_LINE_IDENTIFIER].values,
99100
drug_ids=response_data[DRUG_IDENTIFIER].values,
101+
tissues=response_data[TISSUE_IDENTIFIER].values,
100102
dataset_name=dataset_name,
101103
)
102104

@@ -122,6 +124,7 @@ def _load_toy(
122124
response=response_data[measure].values,
123125
cell_line_ids=response_data[CELL_LINE_IDENTIFIER].values,
124126
drug_ids=response_data[DRUG_IDENTIFIER].values,
127+
tissues=response_data[TISSUE_IDENTIFIER].values,
125128
dataset_name=dataset_name,
126129
)
127130

@@ -171,6 +174,7 @@ def _load_ctrpv(version: str, path_data: str = "data", measure: str = "LN_IC50_c
171174
response=response_data[measure].values,
172175
cell_line_ids=response_data[CELL_LINE_IDENTIFIER].values,
173176
drug_ids=response_data[DRUG_IDENTIFIER].values,
177+
tissues=response_data[TISSUE_IDENTIFIER].values,
174178
dataset_name=dataset_name,
175179
)
176180

@@ -199,16 +203,19 @@ def load_ctrpv2(path_data: str = "data", measure: str = "LN_IC50_curvecurator")
199203
return _load_ctrpv("2", path_data, measure)
200204

201205

202-
def load_custom(path_data: str | Path, measure: str = "response") -> DrugResponseDataset:
206+
def load_custom(
207+
path_data: str | Path, measure: str = "response", tissue_column: str | None = None
208+
) -> DrugResponseDataset:
203209
"""
204210
Load custom dataset.
205211
206212
:param path_data: Path to location of custom dataset
207213
:param measure: The name of the column containing the measure to predict, default = "response"
214+
:param tissue_column: The name of the column containing the tissue type. If None, no tissue information is loaded.
208215
209216
:return: DrugResponseDataset containing response, cell line IDs, and drug IDs
210217
"""
211-
return DrugResponseDataset.from_csv(path_data, measure=measure)
218+
return DrugResponseDataset.from_csv(path_data, measure=measure, tissue_column=tissue_column)
212219

213220

214221
AVAILABLE_DATASETS: dict[str, Callable] = {
@@ -224,7 +231,12 @@ def load_custom(path_data: str | Path, measure: str = "response") -> DrugRespons
224231

225232
@pipeline_function
226233
def load_dataset(
227-
dataset_name: str, path_data: str = "data", measure: str = "response", curve_curator: bool = False, cores: int = 1
234+
dataset_name: str,
235+
path_data: str = "data",
236+
measure: str = "response",
237+
curve_curator: bool = False,
238+
cores: int = 1,
239+
tissue_column: str | None = None,
228240
) -> DrugResponseDataset:
229241
"""
230242
Load a dataset based on the dataset name.
@@ -243,6 +255,8 @@ def load_dataset(
243255
which is expected to exist at <path_data>/<dataset_name>/<dataset_name>_raw.csv. The fitted dataset will
244256
be stored in the same folder, in a file called <dataset_name>.csv
245257
:param cores: Number of cores to use for CurveCurator fitting. Only used when curve_curator is True, default = 1
258+
:param tissue_column: The name of the column containing the tissue type. If None, no tissue information is loaded.
259+
This is only used when loading a custom dataset. Default = None.
246260
:return: A DrugResponseDataset containing response, cell line IDs, drug IDs, and dataset name.
247261
:raises FileNotFoundError: If the custom dataset or raw viability data could not be found at the given path.
248262
"""
@@ -263,5 +277,7 @@ def load_dataset(
263277
dataset_name=dataset_name,
264278
cores=cores,
265279
)
266-
return load_custom(Path(path_data) / dataset_name / f"{dataset_name}.csv", measure=measure)
280+
return load_custom(
281+
Path(path_data) / dataset_name / f"{dataset_name}.csv", measure=measure, tissue_column=tissue_column
282+
)
267283
raise FileNotFoundError(f"Custom dataset does not exist at given path: {input_file}")

0 commit comments

Comments
 (0)