Skip to content

Commit 7680579

Browse files
Merge pull request #132 from daisybio/development
New version
2 parents 431d252 + 342860b commit 7680579

25 files changed

+1448
-1027
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ data/GDSC1
66
data/GDSC2
77
data/CCLE
88
data/Toy_Data
9+
data/CTRPv1
10+
data/CTRPv2
911

1012
# Byte-compiled / optimized / DLL files
1113
__pycache__/

create_report.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,8 +267,10 @@ def draw_per_grouping_algorithm_plots(
267267
if __name__ == "__main__":
268268
parser = argparse.ArgumentParser(description="Generate reports from evaluation results")
269269
parser.add_argument("--run_id", required=True, help="Run ID for the current execution")
270+
parser.add_argument("--dataset", required=True, help="Dataset name for which to render the result file")
270271
args = parser.parse_args()
271272
run_id = args.run_id
273+
dataset = args.dataset
272274

273275
# assert that the run_id folder exists
274276
if not os.path.exists(f"results/{run_id}"):
@@ -280,7 +282,7 @@ def draw_per_grouping_algorithm_plots(
280282
evaluation_results_per_drug,
281283
evaluation_results_per_cell_line,
282284
true_vs_pred,
283-
) = parse_results(path_to_results=f"results/{run_id}")
285+
) = parse_results(path_to_results=f"results/{run_id}", dataset=dataset)
284286

285287
# part of pipeline: EVALUATE_FINAL, COLLECT_RESULTS
286288
(

docs/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
sphinx-autobuild==2024.10.3 ; python_version >= "3.11" and python_version < "3.13"
2-
sphinx-autodoc-typehints==3.0.1 ; python_version >= "3.11" and python_version < "3.13"
2+
sphinx-autodoc-typehints==3.1.0 ; python_version >= "3.11" and python_version < "3.13"
33
sphinx-click==6.0.0 ; python_version >= "3.11" and python_version < "3.13"
44
sphinx-rtd-theme==3.0.2 ; python_version >= "3.11" and python_version < "3.13"

drevalpy/datasets/dataset.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def from_csv(
5555
- response: the drug response values as floating point values
5656
- cell_line_ids: a string identifier for cell lines
5757
- drug_ids: a string identifier for drugs
58-
- predictions: an optional column containing a predicted value TODO what exactly?
58+
- predictions: an optional column containing drug response predictions
5959
6060
:param input_file: Path to the csv file containing the data to be loaded
6161
:param dataset_name: Optional name to associate the dataset with, default = "unknown"
@@ -64,6 +64,8 @@ def from_csv(
6464
:returns: DrugResponseDataset object containing data from provided csv file.
6565
"""
6666
data = pd.read_csv(input_file)
67+
data["drug_id"] = data["drug_id"].astype(str)
68+
6769
if "predictions" in data.columns:
6870
predictions = data["predictions"].values
6971
else:
@@ -152,9 +154,9 @@ def __init__(
152154
"""
153155
super().__init__()
154156
if len(response) != len(cell_line_ids):
155-
raise AssertionError("Response and cell_line_ids have different lengths.")
157+
raise AssertionError("Response and cell line identifiers have different lengths.")
156158
if len(response) != len(drug_ids):
157-
raise AssertionError("Response and drug_ids have different lengths.")
159+
raise AssertionError("Response and drug identifiers have different lengths.")
158160
if predictions is not None and len(response) != len(predictions):
159161
raise AssertionError("Response and predictions have different lengths.")
160162
self._response = response

drevalpy/datasets/loader.py

Lines changed: 67 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,13 @@
99
from ..pipeline_function import pipeline_function
1010
from .curvecurator import fit_curves
1111
from .dataset import DrugResponseDataset
12-
from .utils import download_dataset
12+
from .utils import CELL_LINE_IDENTIFIER, DRUG_IDENTIFIER, download_dataset
1313

1414

1515
def load_gdsc1(
1616
path_data: str = "data",
17-
measure: str = "LN_IC50",
18-
file_name: str = "response_GDSC1.csv",
17+
measure: str = "LN_IC50_curvecurator",
18+
file_name: str = "GDSC1.csv",
1919
dataset_name: str = "GDSC1",
2020
) -> DrugResponseDataset:
2121
"""
@@ -32,18 +32,18 @@ def load_gdsc1(
3232
if not os.path.exists(path):
3333
download_dataset(dataset_name, path_data, redownload=True)
3434

35-
response_data = pd.read_csv(path)
36-
response_data["DRUG_NAME"] = response_data["DRUG_NAME"].str.replace(",", "")
35+
response_data = pd.read_csv(path, dtype={"pubchem_id": str})
36+
response_data[DRUG_IDENTIFIER] = response_data[DRUG_IDENTIFIER].str.replace(",", "")
3737

3838
return DrugResponseDataset(
3939
response=response_data[measure].values,
40-
cell_line_ids=response_data["CELL_LINE_NAME"].values,
41-
drug_ids=response_data["DRUG_NAME"].values,
40+
cell_line_ids=response_data[CELL_LINE_IDENTIFIER].values,
41+
drug_ids=response_data[DRUG_IDENTIFIER].values,
4242
dataset_name=dataset_name,
4343
)
4444

4545

46-
def load_gdsc2(path_data: str = "data", measure: str = "LN_IC50", file_name: str = "response_GDSC2.csv"):
46+
def load_gdsc2(path_data: str = "data", measure: str = "LN_IC50_curvecurator", file_name: str = "GDSC2.csv"):
4747
"""
4848
Loads the GDSC2 dataset.
4949
@@ -57,7 +57,7 @@ def load_gdsc2(path_data: str = "data", measure: str = "LN_IC50", file_name: str
5757

5858

5959
def load_ccle(
60-
path_data: str = "data", measure: str = "LN_IC50", file_name: str = "response_CCLE.csv"
60+
path_data: str = "data", measure: str = "LN_IC50_curvecurator", file_name: str = "CCLE.csv"
6161
) -> DrugResponseDataset:
6262
"""
6363
Loads the CCLE dataset.
@@ -73,18 +73,18 @@ def load_ccle(
7373
if not os.path.exists(path):
7474
download_dataset(dataset_name, path_data, redownload=True)
7575

76-
response_data = pd.read_csv(path)
77-
response_data["DRUG_NAME"] = response_data["DRUG_NAME"].str.replace(",", "")
76+
response_data = pd.read_csv(path, dtype={"pubchem_id": str})
77+
response_data[DRUG_IDENTIFIER] = response_data[DRUG_IDENTIFIER].str.replace(",", "")
7878

7979
return DrugResponseDataset(
8080
response=response_data[measure].values,
81-
cell_line_ids=response_data["CELL_LINE_NAME"].values,
82-
drug_ids=response_data["DRUG_NAME"].values,
81+
cell_line_ids=response_data[CELL_LINE_IDENTIFIER].values,
82+
drug_ids=response_data[DRUG_IDENTIFIER].values,
8383
dataset_name=dataset_name,
8484
)
8585

8686

87-
def load_toy(path_data: str = "data", measure: str = "response") -> DrugResponseDataset:
87+
def load_toy(path_data: str = "data", measure: str = "LN_IC50_curvecurator") -> DrugResponseDataset:
8888
"""
8989
Loads small Toy dataset, subsampled from GDSC1.
9090
@@ -94,20 +94,67 @@ def load_toy(path_data: str = "data", measure: str = "response") -> DrugResponse
9494
:return: DrugResponseDataset containing response, cell line IDs, and drug IDs.
9595
"""
9696
dataset_name = "Toy_Data"
97-
measure = "response" # overwrite this explicitly to avoid problems, should be changed in the future
9897
path = os.path.join(path_data, dataset_name, "toy_data.csv")
9998
if not os.path.exists(path):
10099
download_dataset(dataset_name, path_data, redownload=True)
101-
response_data = pd.read_csv(path)
100+
response_data = pd.read_csv(path, dtype={"pubchem_id": str})
102101

103102
return DrugResponseDataset(
104103
response=response_data[measure].values,
105-
cell_line_ids=response_data["cell_line_id"].values,
106-
drug_ids=response_data["drug_id"].values,
104+
cell_line_ids=response_data[CELL_LINE_IDENTIFIER].values,
105+
drug_ids=response_data[DRUG_IDENTIFIER].values,
107106
dataset_name=dataset_name,
108107
)
109108

110109

110+
def _load_ctrpv(version: str, path_data: str = "data", measure: str = "LN_IC50_curvecurator") -> DrugResponseDataset:
111+
"""
112+
    Load the CTRP dataset for the specified version.
113+
114+
:param version: The version of the CTRP dataset to load.
115+
:param path_data: Path to location of CTRPv1 dataset
116+
:param measure: The name of the column containing the measure to predict, default = "response"
117+
118+
:return: DrugResponseDataset containing response, cell line IDs, and drug IDs
119+
"""
120+
dataset_name = "CTRPv" + version
121+
path = os.path.join(path_data, dataset_name, f"{dataset_name}.csv")
122+
if not os.path.exists(path):
123+
download_dataset(dataset_name, path_data, redownload=True)
124+
response_data = pd.read_csv(path, dtype={"pubchem_id": str})
125+
126+
return DrugResponseDataset(
127+
response=response_data[measure].values,
128+
cell_line_ids=response_data[CELL_LINE_IDENTIFIER].values,
129+
drug_ids=response_data[DRUG_IDENTIFIER].values,
130+
dataset_name=dataset_name,
131+
)
132+
133+
134+
def load_ctrpv1(path_data: str = "data", measure: str = "LN_IC50_curvecurator") -> DrugResponseDataset:
135+
"""
136+
    Load CTRPv1 dataset.
137+
138+
:param path_data: Path to location of CTRPv2 dataset
139+
:param measure: The name of the column containing the measure to predict, default = "LN_IC50_curvecurator"
140+
141+
:return: DrugResponseDataset containing response, cell line IDs, and drug IDs
142+
"""
143+
return _load_ctrpv("1", path_data, measure)
144+
145+
146+
def load_ctrpv2(path_data: str = "data", measure: str = "LN_IC50_curvecurator") -> DrugResponseDataset:
147+
"""
148+
Load CTRPv2 dataset.
149+
150+
:param path_data: Path to location of CTRPv2 dataset
151+
    :param measure: The name of the column containing the measure to predict, default = "LN_IC50_curvecurator"
152+
153+
:return: DrugResponseDataset containing response, cell line IDs, and drug IDs
154+
"""
155+
return _load_ctrpv("2", path_data, measure)
156+
157+
111158
def load_custom(path_data: str | Path, measure: str = "response") -> DrugResponseDataset:
112159
"""
113160
Load custom dataset.
@@ -125,6 +172,8 @@ def load_custom(path_data: str | Path, measure: str = "response") -> DrugRespons
125172
"GDSC2": load_gdsc2,
126173
"CCLE": load_ccle,
127174
"Toy_Data": load_toy,
175+
"CTRPv1": load_ctrpv1,
176+
"CTRPv2": load_ctrpv2,
128177
}
129178

130179

drevalpy/datasets/utils.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99
import numpy as np
1010
import requests
1111

12+
DRUG_IDENTIFIER = "pubchem_id"
13+
CELL_LINE_IDENTIFIER = "cell_line_name"
14+
1215

1316
def download_dataset(
1417
dataset_name: str,
@@ -26,18 +29,18 @@ def download_dataset(
2629
file_name = f"{dataset_name}.zip"
2730
file_path = Path(data_path) / file_name
2831
extracted_folder_path = file_path.with_suffix("")
29-
32+
timeout = 120
3033
# Check if the extracted data exists and skip download if not redownloading
3134
if extracted_folder_path.exists() and not redownload:
3235
print(f"{dataset_name} is already extracted, skipping download.")
3336
else:
3437
url = "https://zenodo.org/doi/10.5281/zenodo.12633909"
3538
# Fetch the latest record
36-
response = requests.get(url, timeout=60)
39+
response = requests.get(url, timeout=timeout)
3740
if response.status_code != 200:
3841
raise requests.exceptions.HTTPError(f"Error fetching record: {response.status_code}")
3942
latest_url = response.links["linkset"]["url"]
40-
response = requests.get(latest_url, timeout=60)
43+
response = requests.get(latest_url, timeout=timeout)
4144
if response.status_code != 200:
4245
raise requests.exceptions.HTTPError(f"Error fetching record: {response.status_code}")
4346
data = response.json()
@@ -50,7 +53,7 @@ def download_dataset(
5053
file_url = name_to_url[file_name]
5154
# Download the file
5255
print(f"Downloading {dataset_name} from {file_url}...")
53-
response = requests.get(file_url, timeout=60)
56+
response = requests.get(file_url, timeout=timeout)
5457
if response.status_code != 200:
5558
raise requests.exceptions.HTTPError(f"Error downloading file {dataset_name}: " f"{response.status_code}")
5659

@@ -61,7 +64,7 @@ def download_dataset(
6164
with zipfile.ZipFile(file_path, "r") as z:
6265
for member in z.infolist():
6366
if not member.filename.startswith("__MACOSX/"):
64-
z.extract(member, os.path.join(data_path, dataset_name))
67+
z.extract(member, os.path.join(data_path))
6568
file_path.unlink() # Remove zip file after extraction
6669

6770
print(f"{dataset_name} data downloaded and extracted to {data_path}")

drevalpy/models/SimpleNeuralNetwork/hyperparameters.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ SimpleNeuralNetwork:
2222
- 128
2323
- 64
2424
- 16
25+
max_epochs:
26+
- 100
2527

2628
MultiOmicsNeuralNetwork:
2729
dropout_prob:
@@ -44,3 +46,5 @@ MultiOmicsNeuralNetwork:
4446
- 32
4547
methylation_pca_components:
4648
- 100
49+
max_epochs:
50+
- 100

drevalpy/models/SimpleNeuralNetwork/multiomics_neural_network.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,10 @@ def train(
110110
cell_line_views=self.cell_line_views,
111111
drug_views=self.drug_views,
112112
output_earlystopping=output_earlystopping,
113+
trainer_params={
114+
"max_epochs": self.hyperparameters.get("max_epochs", 100),
115+
"progress_bar_refresh_rate": 500,
116+
},
113117
batch_size=16,
114118
patience=5,
115119
num_workers=1,

drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,10 @@ def train(
105105
cell_line_views=self.cell_line_views,
106106
drug_views=self.drug_views,
107107
output_earlystopping=output_earlystopping,
108+
trainer_params={
109+
"max_epochs": self.hyperparameters.get("max_epochs", 100),
110+
"progress_bar_refresh_rate": 500,
111+
},
108112
batch_size=16,
109113
patience=5,
110114
num_workers=1 if platform.system() == "Windows" else 8,

drevalpy/models/SimpleNeuralNetwork/utils.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -180,16 +180,15 @@ def fit(
180180
:param model_checkpoint_dir: directory to save the model checkpoints
181181
:raises ValueError: if drug_input is missing
182182
"""
183-
if drug_input is None:
184-
raise ValueError(
185-
"Drug input (fingerprints) are required for SimpleNeuralNetwork and " "MultiOMICsNeuralNetwork."
186-
)
187-
188183
if trainer_params is None:
189184
trainer_params = {
185+
"max_epochs": 100,
190186
"progress_bar_refresh_rate": 500,
191-
"max_epochs": 70,
192187
}
188+
if drug_input is None:
189+
raise ValueError(
190+
"Drug input (fingerprints) are required for SimpleNeuralNetwork and " "MultiOMICsNeuralNetwork."
191+
)
193192

194193
train_dataset = RegressionDataset(
195194
output=output_train,

0 commit comments

Comments
 (0)