44import os
55import shutil
66import warnings
7- from typing import Optional
7+ from typing import Any , Optional
88
99import numpy as np
1010import pandas as pd
1616from .datasets .dataset import DrugResponseDataset , FeatureDataset
1717from .evaluation import evaluate , get_mode
1818from .models import MODEL_FACTORY , MULTI_DRUG_MODEL_FACTORY , SINGLE_DRUG_MODEL_FACTORY
19- from .models .drp_model import DRPModel , SingleDrugModel
19+ from .models .drp_model import DRPModel
2020from .pipeline_function import pipeline_function
2121
2222
@@ -82,14 +82,14 @@ def drug_response_experiment(
8282 :param test_mode: test mode one of "LPO", "LCO", "LDO" (leave-pair-out, leave-cell-line-out, leave-drug-out)
8383 :param overwrite: whether to overwrite existing results
8484 :param path_data: path to the data directory, usually data/
85+ :raises ValueError: if no cv splits are found
8586 """
8687 if baselines is None :
8788 baselines = []
8889 cross_study_datasets = cross_study_datasets or []
8990 result_path = os .path .join (path_out , run_id , test_mode )
9091 split_path = os .path .join (result_path , "splits" )
9192 result_folder_exists = os .path .exists (result_path )
92- randomization_test_views = []
9393 if result_folder_exists and overwrite :
9494 # if results exists, delete them if overwrite is True
9595 print (f"Overwriting existing results at { result_path } " )
@@ -146,6 +146,9 @@ def drug_response_experiment(
146146
147147 model_hpam_set = model_class .get_hyperparameter_set ()
148148
149+ if response_data .cv_splits is None :
150+ raise ValueError ("No cv splits found." )
151+
149152 for split_index , split in enumerate (response_data .cv_splits ):
150153 print (f"################# FOLD { split_index + 1 } /{ len (response_data .cv_splits )} " f"#################" )
151154
@@ -233,7 +236,7 @@ def drug_response_experiment(
233236 best_hpams = json .load (f )
234237 if not is_baseline :
235238 if randomization_mode is not None :
236- print (f"Randomization tests for { model_class .model_name } " )
239+ print (f"Randomization tests for { model_class .get_model_name () } " )
237240 # if this line changes, it also needs to be changed in pipeline:
238241 # randomization_split.py
239242 randomization_test_views = get_randomization_test_views (
@@ -253,7 +256,7 @@ def drug_response_experiment(
253256 response_transformation = response_transformation ,
254257 )
255258 if n_trials_robustness > 0 :
256- print (f"Robustness test for { model_class .model_name } " )
259+ print (f"Robustness test for { model_class .get_model_name () } " )
257260 robustness_test (
258261 n_trials = n_trials_robustness ,
259262 model = model ,
@@ -289,7 +292,7 @@ def consolidate_single_drug_model_predictions(
289292 out_path : str = "" ,
290293) -> None :
291294 """
292- Consolidate SingleDrugModel predictions into a single file.
295+ Consolidate single drug model predictions into a single file.
293296
294297 :param models: list of model classes to compare, e.g., [SimpleNeuralNetwork, RandomForest]
295298 :param n_cv_splits: number of cross-validation splits, e.g., 5
@@ -301,10 +304,11 @@ def consolidate_single_drug_model_predictions(
301304 will be stored in the work directory.
302305 """
303306 for model in models :
304- if model .model_name in SINGLE_DRUG_MODEL_FACTORY :
305- model_instance = MODEL_FACTORY [model .model_name ]()
306- model_path = os .path .join (results_path , str (model .model_name ))
307- out_path = os .path .join (out_path , str (model .model_name ))
307+ if model .get_model_name () in SINGLE_DRUG_MODEL_FACTORY :
308+
309+ model_instance = MODEL_FACTORY [model .get_model_name ()]()
310+ model_path = os .path .join (results_path , model .get_model_name ())
311+ out_path = os .path .join (out_path , model .get_model_name ())
308312 os .makedirs (os .path .join (out_path , "predictions" ), exist_ok = True )
309313 if cross_study_datasets :
310314 os .makedirs (os .path .join (out_path , "cross_study" ), exist_ok = True )
@@ -316,7 +320,7 @@ def consolidate_single_drug_model_predictions(
316320 for split in range (n_cv_splits ):
317321
318322 # Collect predictions for drugs across all scenarios (main, cross_study, robustness, randomization)
319- predictions = {
323+ predictions : Any = {
320324 "main" : [],
321325 "cross_study" : {},
322326 "robustness" : {},
@@ -423,14 +427,14 @@ def consolidate_single_drug_model_predictions(
423427
424428def load_features (
425429 model : DRPModel , path_data : str , dataset : DrugResponseDataset
426- ) -> tuple [Optional [ FeatureDataset ] , Optional [FeatureDataset ]]:
430+ ) -> tuple [FeatureDataset , Optional [FeatureDataset ]]:
427431 """
428432 Load and reduce cell line and drug features for a given dataset.
429433
430434 :param model: model to use, e.g., SimpleNeuralNetwork
431435 :param path_data: path to the data directory, e.g., data/
432436 :param dataset: dataset to load features for, e.g., GDSC2
433- :returns: tuple of cell line and drug features
437+ :returns: tuple of cell line and, potentially, drug features
434438 """
435439 cl_features = model .load_cell_line_features (data_path = path_data , dataset_name = dataset .dataset_name )
436440 drug_features = model .load_drug_features (data_path = path_data , dataset_name = dataset .dataset_name )
@@ -480,10 +484,11 @@ def cross_study_prediction(
480484
481485 cell_lines_to_keep = cl_features .identifiers if cl_features is not None else None
482486
487+ drugs_to_keep : Optional [np .ndarray ] = None
483488 if single_drug_id is not None :
484489 drugs_to_keep = np .array ([single_drug_id ])
485- else :
486- drugs_to_keep = drug_features .identifiers if drug_features is not None else None
490+ elif drug_features is not None :
491+ drugs_to_keep = drug_features .identifiers
487492
488493 print (
489494 f"Reducing cross study dataset ... feature data available for "
@@ -778,12 +783,15 @@ def randomize_train_predict(
778783 )
779784 return
780785
781- cl_features_rand = cl_features .copy () if cl_features is not None else None
782- drug_features_rand = drug_features .copy () if drug_features is not None else None
783- if cl_features_rand is not None and view in cl_features .get_view_names ():
784- cl_features_rand .randomize_features (view , randomization_type = randomization_type )
785- elif drug_features_rand is not None and view in drug_features .get_view_names ():
786- drug_features_rand .randomize_features (view , randomization_type = randomization_type )
786+ cl_features_rand : Optional [FeatureDataset ] = None
787+ if cl_features is not None :
788+ cl_features_rand = cl_features .copy ()
789+ cl_features_rand .randomize_features (view , randomization_type = randomization_type ) # type: ignore[union-attr]
790+
791+ drug_features_rand : Optional [FeatureDataset ] = None
792+ if drug_features is not None :
793+ drug_features_rand = drug_features .copy ()
794+ drug_features_rand .randomize_features (view , randomization_type = randomization_type ) # type: ignore[union-attr]
787795
788796 test_dataset_rand = train_and_predict (
789797 model = model ,
@@ -1069,11 +1077,11 @@ def make_model_list(models: list[type[DRPModel]], response_data: DrugResponseDat
10691077 model_list = {}
10701078 unique_drugs = np .unique (response_data .drug_ids )
10711079 for model in models :
1072- if issubclass ( model , SingleDrugModel ) :
1080+ if model . is_single_drug_model :
10731081 for drug in unique_drugs :
1074- model_list [f"{ model .model_name } .{ drug } " ] = str ( model .model_name )
1082+ model_list [f"{ model .get_model_name () } .{ drug } " ] = model .get_model_name ( )
10751083 else :
1076- model_list [str ( model .model_name )] = str ( model .model_name )
1084+ model_list [model .get_model_name ( )] = model .get_model_name ( )
10771085 return model_list
10781086
10791087
0 commit comments