
Commit 5b1ce9e

Merge pull request #160 from daisybio/development
v1.2.5
2 parents: 5d4876a + d7bb0fc

File tree

7 files changed: +198 −145 lines

docs/conf.py

Lines changed: 2 additions & 2 deletions
@@ -56,9 +56,9 @@
 # the built documents.
 #
 # The short X.Y version.
-version = "1.2.4"
+version = "1.2.5"
 # The full version, including alpha/beta/rc tags.
-release = "1.2.4"
+release = "1.2.5"

 # The language for content autogenerated by Sphinx. Refer to documentation
 # for a list of supported languages.
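
The same version string lives in docs/conf.py and pyproject.toml (bumped below), so each release touches both files. As an aside, a common Sphinx pattern reads the version from the installed package's metadata instead, so only pyproject.toml needs a bump; a sketch under that assumption, not what this commit does:

# docs/conf.py (hypothetical alternative, not part of this commit)
from importlib.metadata import version as pkg_version

release = pkg_version("drevalpy")  # full version, e.g. "1.2.5"
version = ".".join(release.split(".")[:2])  # short X.Y version, e.g. "1.2"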

drevalpy/evaluation.py

Lines changed: 11 additions & 1 deletion
@@ -255,6 +255,16 @@ def evaluate(dataset: DrugResponseDataset, metric: list[str] | str):
                 )
             )
         else:
-            results[m] = float(AVAILABLE_METRICS[m](y_pred=predictions, y_true=response))
+            # check whether the predictions contain NaNs
+            if np.any(np.isnan(predictions)):
+                # if there are only NaNs in the predictions, the metric is NaN
+                if np.all(np.isnan(predictions)):
+                    results[m] = float(np.nan)
+                else:
+                    # remove the rows with NaNs in the predictions and response
+                    mask = ~np.isnan(predictions)
+                    results[m] = float(AVAILABLE_METRICS[m](y_pred=predictions[mask], y_true=response[mask]))
+            else:
+                results[m] = float(AVAILABLE_METRICS[m](y_pred=predictions, y_true=response))

     return results
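
This change makes evaluate tolerant of partially missing predictions: an all-NaN prediction vector yields NaN for the metric, otherwise the NaN rows are masked out before scoring (note the mask is built from the predictions only). A self-contained sketch of the same masking logic, using sklearn's mean_squared_error as a stand-in for an AVAILABLE_METRICS entry:

import numpy as np
from sklearn.metrics import mean_squared_error

predictions = np.array([1.0, np.nan, 3.0, np.nan])
response = np.array([1.2, 2.0, 2.8, 4.0])

if np.all(np.isnan(predictions)):
    result = float("nan")  # nothing to score
else:
    mask = ~np.isnan(predictions)  # keep only rows with a real prediction
    result = float(mean_squared_error(y_true=response[mask], y_pred=predictions[mask]))
print(result)  # MSE over the two non-NaN rows only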

drevalpy/visualization/utils.py

Lines changed: 1 addition & 1 deletion
@@ -111,7 +111,7 @@ def parse_results(path_to_results: str, dataset: str) -> tuple[pd.DataFrame, pd.
 @pipeline_function
 def evaluate_file(
     pred_file: pathlib.Path, test_mode: str, model_name: str
-) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame, str]:
+) -> tuple[pd.DataFrame, pd.DataFrame | None, pd.DataFrame | None, pd.DataFrame, str]:
     """
     Evaluate the predictions from the final models.
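
Because the second and third elements of the returned tuple (the per-drug and per-cell-line evaluations) may now be None, callers have to guard them before use. A usage sketch; the call signature follows the diff, while the file path and model name are illustrative:

import pathlib

from drevalpy.visualization.utils import evaluate_file

overall_eval, per_drug, per_cell_line, true_vs_pred, model = evaluate_file(
    pred_file=pathlib.Path("predictions_split_0.csv"),  # illustrative path
    test_mode="LCO",
    model_name="NaivePredictor",  # illustrative model name
)
if per_drug is not None:  # may be None under the new signature
    print(per_drug.head())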

poetry.lock

Lines changed: 88 additions & 88 deletions
Generated lockfile; diff not rendered.

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "drevalpy"
-version = "1.2.4"
+version = "1.2.5"
 description = "Drug response evaluation of cancer cell line drug response models in a fair setting"
 authors = ["DrEvalPy development team"]
 license = "GPL-3.0"

requirements.txt

Lines changed: 5 additions & 5 deletions
@@ -1,4 +1,4 @@
-aiohappyeyeballs==2.4.8 ; python_version >= "3.11" and python_version < "3.13"
+aiohappyeyeballs==2.5.0 ; python_version >= "3.11" and python_version < "3.13"
 aiohttp==3.11.13 ; python_version >= "3.11" and python_version < "3.13"
 aiosignal==1.3.2 ; python_version >= "3.11" and python_version < "3.13"
 anyio==4.8.0 ; python_version >= "3.11" and python_version < "3.13"
@@ -38,18 +38,18 @@ jaraco-classes==3.4.0 ; python_version >= "3.11" and python_version < "3.13"
 jaraco-context==6.0.1 ; python_version >= "3.11" and python_version < "3.13"
 jaraco-functools==4.1.0 ; python_version >= "3.11" and python_version < "3.13"
 jeepney==0.9.0 ; python_version >= "3.11" and python_version < "3.13" and sys_platform == "linux"
-jinja2==3.1.5 ; python_version >= "3.11" and python_version < "3.13"
+jinja2==3.1.6 ; python_version >= "3.11" and python_version < "3.13"
 joblib==1.4.2 ; python_version >= "3.11" and python_version < "3.13"
 keyring==25.6.0 ; python_version >= "3.11" and python_version < "3.13"
 kiwisolver==1.4.8 ; python_version >= "3.11" and python_version < "3.13"
-lightning-utilities==0.13.1 ; python_version >= "3.11" and python_version < "3.13"
+lightning-utilities==0.14.0 ; python_version >= "3.11" and python_version < "3.13"
 markupsafe==3.0.2 ; python_version >= "3.11" and python_version < "3.13"
 matplotlib==3.10.1 ; python_version >= "3.11" and python_version < "3.13"
 more-itertools==10.6.0 ; python_version >= "3.11" and python_version < "3.13"
 mpmath==1.3.0 ; python_version >= "3.11" and python_version < "3.13"
 msgpack==1.1.0 ; python_version >= "3.11" and python_version < "3.13"
 multidict==6.1.0 ; python_version >= "3.11" and python_version < "3.13"
-narwhals==1.29.0 ; python_version >= "3.11" and python_version < "3.13"
+narwhals==1.29.1 ; python_version >= "3.11" and python_version < "3.13"
 networkx==3.4.2 ; python_version >= "3.11" and python_version < "3.13"
 numpy==1.26.4 ; python_version >= "3.11" and python_version < "3.13"
 nvidia-cublas-cu12==12.4.5.8 ; python_version >= "3.11" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
@@ -117,7 +117,7 @@ trove-classifiers==2025.3.3.18 ; python_version >= "3.11" and python_version < "3.13"
 typing-extensions==4.12.2 ; python_version >= "3.11" and python_version < "3.13"
 tzdata==2025.1 ; python_version >= "3.11" and python_version < "3.13"
 urllib3==2.3.0 ; python_version >= "3.11" and python_version < "3.13"
-virtualenv==20.29.2 ; python_version >= "3.11" and python_version < "3.13"
+virtualenv==20.29.3 ; python_version >= "3.11" and python_version < "3.13"
 xarray==2025.1.2 ; python_version >= "3.11" and python_version < "3.13"
 xattr==1.1.4 ; python_version >= "3.11" and python_version < "3.13" and sys_platform == "darwin"
 xyzservices==2025.1.0 ; python_version >= "3.11" and python_version < "3.13"

tests/individual_models/test_baselines.py

Lines changed: 90 additions & 47 deletions
@@ -1,15 +1,23 @@
 """Tests for the baselines in the models module."""

+import pathlib
 import tempfile
 from typing import cast

 import numpy as np
+import pandas as pd
 import pytest
 from sklearn.linear_model import ElasticNet, Ridge

 from drevalpy.datasets.dataset import DrugResponseDataset, FeatureDataset
-from drevalpy.evaluation import evaluate, pearson
-from drevalpy.experiment import cross_study_prediction
+from drevalpy.evaluation import evaluate
+from drevalpy.experiment import (
+    consolidate_single_drug_model_predictions,
+    cross_study_prediction,
+    generate_data_saving_path,
+    get_datasets_from_cv_split,
+    train_and_predict,
+)
 from drevalpy.models import (
     MODEL_FACTORY,
     NaiveCellLineMeanPredictor,
@@ -19,6 +27,7 @@
 )
 from drevalpy.models.baselines.sklearn_models import SklearnModel
 from drevalpy.models.drp_model import DRPModel
+from drevalpy.visualization.utils import evaluate_file


 @pytest.mark.parametrize(
@@ -146,61 +155,67 @@ def test_single_drug_baselines(
     :param test_mode: either LPO or LCO
     :param cross_study_dataset: dataset
     """
-    drug_response = sample_dataset
-    drug_response.split_dataset(
+    sample_dataset.split_dataset(
         n_cv_splits=5,
         mode=test_mode,
     )
-    assert drug_response.cv_splits is not None
-    split = drug_response.cv_splits[0]
-    train_dataset = split["train"]
-    val_dataset = split["validation"]
-
+    assert sample_dataset.cv_splits is not None
+    split = sample_dataset.cv_splits[0]
     model = MODEL_FACTORY[model_name]()
-    cell_line_input = model.load_cell_line_features(data_path="../data", dataset_name="TOYv1")
-    cell_lines_to_keep = cell_line_input.identifiers

-    len_train_before = len(train_dataset)
-    len_pred_before = len(val_dataset)
-    train_dataset.reduce_to(cell_line_ids=cell_lines_to_keep, drug_ids=None)
-    val_dataset.reduce_to(cell_line_ids=cell_lines_to_keep, drug_ids=None)
-    print(f"Reduced training dataset from {len_train_before} to {len(train_dataset)}")
-    print(f"Reduced val dataset from {len_pred_before} to {len(val_dataset)}")
-
-    all_unique_drugs = np.unique(train_dataset.drug_ids)
+    # test what happens if a drug is only in the original dataset, not in the cross-study dataset
+    exclusive_drugs = list(set(sample_dataset.drug_ids).difference(set(cross_study_dataset.drug_ids)))
+    all_unique_drugs = list(set(sample_dataset.drug_ids).intersection(set(cross_study_dataset.drug_ids)))
+    all_unique_drugs.sort()
+    exclusive_drugs.sort()
+    all_unique_drugs_arr = np.array(all_unique_drugs)
+    exclusive_drugs_arr = np.array(exclusive_drugs)
     # randomly sample a drug to speed up testing
     np.random.seed(123)
-    np.random.shuffle(all_unique_drugs)
-    random_drug = all_unique_drugs[:1]
-
-    all_predictions = np.zeros_like(val_dataset.drug_ids, dtype=float)
+    np.random.shuffle(all_unique_drugs_arr)
+    np.random.shuffle(exclusive_drugs_arr)
+    random_drugs = all_unique_drugs_arr[:1]
+    random_drugs = np.concatenate([random_drugs, exclusive_drugs_arr[:1]])
+    # test what happens if the training and validation dataset is empty for a drug but the test set is not
+    drug_to_remove = all_unique_drugs_arr[2]
+    random_drugs = np.concatenate([random_drugs, [drug_to_remove]])

     hpam_combi = model.get_hyperparameter_set()[0]
+    result_path = tempfile.TemporaryDirectory()
     if model_name == "SingleDrugRandomForest":
         hpam_combi["n_estimators"] = 2  # reduce test time
         hpam_combi["max_depth"] = 2  # reduce test time
-
-    model.build_model(hpam_combi)
-    output_mask = train_dataset.drug_ids == random_drug
-    drug_train = train_dataset.copy()
-    drug_train.mask(output_mask)
-    model.train(output=drug_train, cell_line_input=cell_line_input, drug_input=None)
-
-    val_mask = val_dataset.drug_ids == random_drug
-    all_predictions[val_mask] = model.predict(
-        drug_ids=random_drug,
-        cell_line_ids=val_dataset.cell_line_ids[val_mask],
-        cell_line_input=cell_line_input,
-    )
-    # check whether predictions are constant
-    if np.all(all_predictions[val_mask] == all_predictions[val_mask][0]):
-        print("Predictions are constant")
-    else:
-        pcc_drug = pearson(val_dataset.response[val_mask], all_predictions[val_mask])
-        print(f"{test_mode}: Performance of {model_name} for drug {random_drug}: PCC = {pcc_drug}")
-        assert pcc_drug >= -1.0
-    with tempfile.TemporaryDirectory() as temp_dir:
-        print(f"Running cross-study prediction for {model_name}")
+    for random_drug in random_drugs:
+        predictions_path = generate_data_saving_path(
+            model_name=model_name,
+            drug_id=str(random_drug),
+            result_path=result_path.name,
+            suffix="predictions",
+        )
+        prediction_file = pathlib.Path(predictions_path, "predictions_split_0.csv")
+        (
+            train_dataset,
+            validation_dataset,
+            early_stopping_dataset,
+            test_dataset,
+        ) = get_datasets_from_cv_split(split, MODEL_FACTORY[model_name], model_name, random_drug)
+        train_dataset.add_rows(validation_dataset)
+        if random_drug == drug_to_remove:
+            reduce_to_drugs = np.array(list(set(train_dataset.drug_ids) - {random_drug}))
+            train_dataset.reduce_to(cell_line_ids=None, drug_ids=reduce_to_drugs)
+        train_dataset.shuffle(random_state=42)
+        test_dataset = train_and_predict(
+            model=model,
+            hpams=hpam_combi,
+            path_data="../data",
+            train_dataset=train_dataset,
+            prediction_dataset=test_dataset,
+            early_stopping_dataset=None,
+            response_transformation=None,
+            model_checkpoint_dir="TEMPORARY",
+        )
+        cross_study_dataset.remove_nan_responses()
+        parent_dir = str(pathlib.Path(predictions_path).parent)
         cross_study_prediction(
             dataset=cross_study_dataset,
             model=model,
@@ -209,10 +224,38 @@ def test_single_drug_baselines(
             path_data="../data",
             early_stopping_dataset=None,
             response_transformation=None,
-            path_out=temp_dir,
+            path_out=parent_dir,
             split_index=0,
-            single_drug_id=str(random_drug[0]),
+            single_drug_id=str(random_drug),
         )
+        test_dataset.to_csv(prediction_file)
+    consolidate_single_drug_model_predictions(
+        models=[MODEL_FACTORY[model_name]],
+        n_cv_splits=1,
+        results_path=result_path.name,
+        cross_study_datasets=[cross_study_dataset.dataset_name],
+        randomization_mode=None,
+        n_trials_robustness=0,
+        out_path=result_path.name,
+    )
+    # get cross-study predictions and assert that each drug-cell line combination only occurs once
+    cross_study_predictions = pd.read_csv(
+        pathlib.Path(result_path.name, model_name, "cross_study", "cross_study_TOYv2_split_0.csv")
+    )
+    assert len(cross_study_predictions) == len(cross_study_predictions.drop_duplicates(["drug_id", "cell_line_id"]))
+    predictions_file = pathlib.Path(result_path.name, model_name, "predictions", "predictions_split_0.csv")
+    cross_study_file = pathlib.Path(result_path.name, model_name, "cross_study", "cross_study_TOYv2_split_0.csv")
+    for file in [predictions_file, cross_study_file]:
+        (
+            overall_eval,
+            eval_results_per_drug,
+            eval_results_per_cl,
+            t_vs_p,
+            model_name,
+        ) = evaluate_file(pred_file=file, test_mode=test_mode, model_name=model_name)
+        assert len(overall_eval) == 1
+        print(f"Performance of {model_name}: PCC = {overall_eval['Pearson'][0]}")
+        assert overall_eval["Pearson"][0] >= -1.0


 def _call_naive_predictor(
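
The reworked test now drives the full single-drug pipeline per sampled drug (CV split, train, predict, cross-study prediction, consolidation) and then checks two properties: the consolidated cross-study file contains each drug/cell-line combination at most once, and the overall Pearson correlation lies in its valid range (>= -1). The uniqueness check is plain pandas; a toy sketch of the same assertion on made-up data:

import pandas as pd

cross_study_predictions = pd.DataFrame(
    {
        "drug_id": ["D1", "D1", "D2"],
        "cell_line_id": ["C1", "C2", "C1"],
        "response": [0.1, 0.2, 0.3],
    }
)
# every (drug_id, cell_line_id) pair must occur exactly once
assert len(cross_study_predictions) == len(
    cross_study_predictions.drop_duplicates(["drug_id", "cell_line_id"])
)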
