11"""Tests for the baselines in the models module."""
22
3+ import pathlib
34import tempfile
45from typing import cast
56
67import numpy as np
8+ import pandas as pd
79import pytest
810from sklearn .linear_model import ElasticNet , Ridge
911
1012from drevalpy .datasets .dataset import DrugResponseDataset , FeatureDataset
11- from drevalpy .evaluation import evaluate , pearson
12- from drevalpy .experiment import cross_study_prediction
13+ from drevalpy .evaluation import evaluate
14+ from drevalpy .experiment import (
15+ consolidate_single_drug_model_predictions ,
16+ cross_study_prediction ,
17+ generate_data_saving_path ,
18+ get_datasets_from_cv_split ,
19+ train_and_predict ,
20+ )
1321from drevalpy .models import (
1422 MODEL_FACTORY ,
1523 NaiveCellLineMeanPredictor ,
1927)
2028from drevalpy .models .baselines .sklearn_models import SklearnModel
2129from drevalpy .models .drp_model import DRPModel
30+ from drevalpy .visualization .utils import evaluate_file
2231
2332
2433@pytest .mark .parametrize (
@@ -146,61 +155,67 @@ def test_single_drug_baselines(
    :param test_mode: either LPO or LCO
    :param cross_study_dataset: dataset
    """
-     drug_response = sample_dataset
-     drug_response.split_dataset(
+     sample_dataset.split_dataset(
        n_cv_splits=5,
        mode=test_mode,
    )
-     assert drug_response.cv_splits is not None
-     split = drug_response.cv_splits[0]
-     train_dataset = split["train"]
-     val_dataset = split["validation"]
-
+     assert sample_dataset.cv_splits is not None
+     split = sample_dataset.cv_splits[0]
    model = MODEL_FACTORY[model_name]()
-     cell_line_input = model.load_cell_line_features(data_path="../data", dataset_name="TOYv1")
-     cell_lines_to_keep = cell_line_input.identifiers

-     len_train_before = len(train_dataset)
-     len_pred_before = len(val_dataset)
-     train_dataset.reduce_to(cell_line_ids=cell_lines_to_keep, drug_ids=None)
-     val_dataset.reduce_to(cell_line_ids=cell_lines_to_keep, drug_ids=None)
-     print(f"Reduced training dataset from {len_train_before} to {len(train_dataset)}")
-     print(f"Reduced val dataset from {len_pred_before} to {len(val_dataset)}")
-
-     all_unique_drugs = np.unique(train_dataset.drug_ids)
+     # test what happens if a drug is only in the original dataset, not in the cross-study dataset
+     exclusive_drugs = list(set(sample_dataset.drug_ids).difference(set(cross_study_dataset.drug_ids)))
+     all_unique_drugs = list(set(sample_dataset.drug_ids).intersection(set(cross_study_dataset.drug_ids)))
+     all_unique_drugs.sort()
+     exclusive_drugs.sort()
+     all_unique_drugs_arr = np.array(all_unique_drugs)
+     exclusive_drugs_arr = np.array(exclusive_drugs)
    # randomly sample a drug to speed up testing
    np.random.seed(123)
-     np.random.shuffle(all_unique_drugs)
-     random_drug = all_unique_drugs[:1]
-
-     all_predictions = np.zeros_like(val_dataset.drug_ids, dtype=float)
+     np.random.shuffle(all_unique_drugs_arr)
+     np.random.shuffle(exclusive_drugs_arr)
+     random_drugs = all_unique_drugs_arr[:1]
+     random_drugs = np.concatenate([random_drugs, exclusive_drugs_arr[:1]])
+     # test what happens if the training and validation datasets are empty for a drug but the test set is not
+     drug_to_remove = all_unique_drugs_arr[2]
+     random_drugs = np.concatenate([random_drugs, [drug_to_remove]])
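+     # random_drugs now holds one drug shared with the cross-study dataset, one drug exclusive to the
+     # original dataset, and one drug whose training/validation data is emptied below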

    hpam_combi = model.get_hyperparameter_set()[0]
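+     # temporary directory standing in for the pipeline's results folder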
+     result_path = tempfile.TemporaryDirectory()
    if model_name == "SingleDrugRandomForest":
        hpam_combi["n_estimators"] = 2  # reduce test time
        hpam_combi["max_depth"] = 2  # reduce test time
-
-     model.build_model(hpam_combi)
-     output_mask = train_dataset.drug_ids == random_drug
-     drug_train = train_dataset.copy()
-     drug_train.mask(output_mask)
-     model.train(output=drug_train, cell_line_input=cell_line_input, drug_input=None)
-
-     val_mask = val_dataset.drug_ids == random_drug
-     all_predictions[val_mask] = model.predict(
-         drug_ids=random_drug,
-         cell_line_ids=val_dataset.cell_line_ids[val_mask],
-         cell_line_input=cell_line_input,
-     )
-     # check whether predictions are constant
-     if np.all(all_predictions[val_mask] == all_predictions[val_mask][0]):
-         print("Predictions are constant")
-     else:
-         pcc_drug = pearson(val_dataset.response[val_mask], all_predictions[val_mask])
-         print(f"{test_mode}: Performance of {model_name} for drug {random_drug}: PCC = {pcc_drug}")
-         assert pcc_drug >= -1.0
-     with tempfile.TemporaryDirectory() as temp_dir:
-         print(f"Running cross-study prediction for {model_name}")
+     for random_drug in random_drugs:
+         predictions_path = generate_data_saving_path(
+             model_name=model_name,
+             drug_id=str(random_drug),
+             result_path=result_path.name,
+             suffix="predictions",
+         )
+         prediction_file = pathlib.Path(predictions_path, "predictions_split_0.csv")
+         (
+             train_dataset,
+             validation_dataset,
+             early_stopping_dataset,
+             test_dataset,
+         ) = get_datasets_from_cv_split(split, MODEL_FACTORY[model_name], model_name, random_drug)
+         train_dataset.add_rows(validation_dataset)
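+         # for drug_to_remove, drop all of its rows so its training/validation data is empty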
+         if random_drug == drug_to_remove:
+             reduce_to_drugs = np.array(list(set(train_dataset.drug_ids) - {random_drug}))
+             train_dataset.reduce_to(cell_line_ids=None, drug_ids=reduce_to_drugs)
+         train_dataset.shuffle(random_state=42)
+         test_dataset = train_and_predict(
+             model=model,
+             hpams=hpam_combi,
+             path_data="../data",
+             train_dataset=train_dataset,
+             prediction_dataset=test_dataset,
+             early_stopping_dataset=None,
+             response_transformation=None,
+             model_checkpoint_dir="TEMPORARY",
+         )
+         cross_study_dataset.remove_nan_responses()
+         parent_dir = str(pathlib.Path(predictions_path).parent)
        cross_study_prediction(
            dataset=cross_study_dataset,
            model=model,
@@ -209,10 +224,38 @@ def test_single_drug_baselines(
            path_data="../data",
            early_stopping_dataset=None,
            response_transformation=None,
-             path_out=temp_dir,
+             path_out=parent_dir,
            split_index=0,
-             single_drug_id=str(random_drug[0]),
+             single_drug_id=str(random_drug),
        )
+         test_dataset.to_csv(prediction_file)
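+     # consolidate the per-drug prediction and cross-study files into model-level files for split 0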
+     consolidate_single_drug_model_predictions(
+         models=[MODEL_FACTORY[model_name]],
+         n_cv_splits=1,
+         results_path=result_path.name,
+         cross_study_datasets=[cross_study_dataset.dataset_name],
+         randomization_mode=None,
+         n_trials_robustness=0,
+         out_path=result_path.name,
+     )
+     # get cross-study predictions and assert that each drug-cell line combination only occurs once
+     cross_study_predictions = pd.read_csv(
+         pathlib.Path(result_path.name, model_name, "cross_study", "cross_study_TOYv2_split_0.csv")
+     )
+     assert len(cross_study_predictions) == len(cross_study_predictions.drop_duplicates(["drug_id", "cell_line_id"]))
+     predictions_file = pathlib.Path(result_path.name, model_name, "predictions", "predictions_split_0.csv")
+     cross_study_file = pathlib.Path(result_path.name, model_name, "cross_study", "cross_study_TOYv2_split_0.csv")
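+     # run the evaluation utilities on both the consolidated predictions and the cross-study predictions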
+     for file in [predictions_file, cross_study_file]:
+         (
+             overall_eval,
+             eval_results_per_drug,
+             eval_results_per_cl,
+             t_vs_p,
+             model_name,
+         ) = evaluate_file(pred_file=file, test_mode=test_mode, model_name=model_name)
+         assert len(overall_eval) == 1
+         print(f"Performance of {model_name}: PCC = {overall_eval['Pearson'][0]}")
+         assert overall_eval["Pearson"][0] >= -1.0


def _call_naive_predictor(