Commit 3afdf81

Merge pull request #72 from daisybio/mypy_fix
Mypy fix
2 parents: baa6639 + 9fe6f88

36 files changed: 958 additions, 561 deletions

create_report.py

Lines changed: 2 additions & 2 deletions

@@ -341,8 +341,8 @@ def draw_per_grouping_algorithm_plots(
         custom_id=run_id,
     )
     # get all html files from results/{run_id}
-    all_files = []
-    for _, _, files in os.walk(f"results/{run_id}"):
+    all_files: list[str] = []
+    for _, _, files in os.walk(f"results/{run_id}"):  # type: ignore[assignment]
         for file in files:
             if file.endswith(".html") and file not in ["index.html", "LPO.html", "LCO.html", "LDO.html"]:
                 all_files.append(file)
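
The annotation on the empty list is the usual way to give mypy the element type up front; otherwise the type of all_files has to be inferred from later appends. A minimal standalone sketch of the pattern (the results path below is made up, not the repository's):

    import os

    all_files: list[str] = []  # annotation tells mypy this collects strings
    for _, _, files in os.walk("results/example_run"):
        for file in files:
            if file.endswith(".html"):
                all_files.append(file)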

drevalpy/datasets/dataset.py

Lines changed: 3 additions & 4 deletions

@@ -83,7 +83,7 @@ def __init__(
         if len(self.response) != len(self.drug_ids):
             raise AssertionError("response and drug_ids/cell_line_ids have different lengths")
         # Used in the pipeline!
-        self.dataset_name = dataset_name
+        self.dataset_name = dataset_name if dataset_name is not None else ""

         self.predictions: Optional[np.ndarray] = None
         if predictions is not None:

@@ -785,14 +785,13 @@ def get_view_names(self) -> list[str]:
         """
         return list(self.features[list(self.features.keys())[0]].keys())

-    def get_feature_matrix(self, view: str, identifiers: np.ndarray, stack: bool = True) -> np.ndarray:
+    def get_feature_matrix(self, view: str, identifiers: np.ndarray) -> np.ndarray:
         """
         Returns the feature matrix for the given view.

         The feature view must be a vector or matrix.
         :param view: view name
         :param identifiers: list of identifiers (cell lines oder drugs)
-        :param stack: if True, the feature vectors are stacked to a matrix
         :returns: feature matrix
         :raises AssertionError: if no identifiers are given
         :raises AssertionError: if view is not in the FeatureDataset

@@ -818,7 +817,7 @@ def get_feature_matrix(self, view: str, identifiers: np.ndarray, stack: bool = T
         if not all(isinstance(self.features[id_][view], np.ndarray) for id_ in identifiers):
             raise AssertionError(f"get_feature_matrix only works for vectors or matrices. {view} is not a numpy array.")
         out = np.array([self.features[id_][view] for id_ in identifiers])
-        return np.stack(out, axis=0)
+        return out

     def copy(self):
         """Returns a copy of the feature dataset.

drevalpy/datasets/loader.py

Lines changed: 7 additions & 1 deletion

@@ -1,6 +1,7 @@
 """Contains functions to load the GDSC1, GDSC2, CCLE, and Toy datasets."""

 import os
+from typing import Callable

 import pandas as pd

@@ -91,7 +92,12 @@ def load_toy(path_data: str = "data") -> DrugResponseDataset:
     )


-AVAILABLE_DATASETS = {"GDSC1": load_gdsc1, "GDSC2": load_gdsc2, "CCLE": load_ccle, "Toy_Data": load_toy}
+AVAILABLE_DATASETS: dict[str, Callable] = {
+    "GDSC1": load_gdsc1,
+    "GDSC2": load_gdsc2,
+    "CCLE": load_ccle,
+    "Toy_Data": load_toy,
+}


 @pipeline_function
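
Typing the registry as dict[str, Callable] gives mypy something concrete to check at lookup sites; a tighter annotation such as Callable[..., DrugResponseDataset] would narrow it further. A hedged sketch of how a registry like this is typically consumed (the loader body below is a placeholder, not the real function):

    from typing import Callable

    def load_toy(path_data: str = "data") -> dict:  # placeholder return type for illustration
        return {"dataset": "Toy_Data", "path": path_data}

    AVAILABLE_DATASETS: dict[str, Callable] = {"Toy_Data": load_toy}
    loader = AVAILABLE_DATASETS["Toy_Data"]  # mypy sees a Callable
    print(loader())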

drevalpy/evaluation.py

Lines changed: 2 additions & 0 deletions

@@ -234,6 +234,8 @@ def evaluate(dataset: DrugResponseDataset, metric: list[str] | str):
     if isinstance(metric, str):
         metric = [metric]
     predictions = dataset.predictions
+    if predictions is None:
+        raise AssertionError("No predictions found in the dataset")
     response = dataset.response

     results = {}
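
The added check doubles as a type guard: after the raise, mypy narrows Optional[np.ndarray] down to np.ndarray, so later arithmetic on predictions type-checks. A minimal illustration of the narrowing pattern (names are illustrative, not the evaluate implementation):

    from typing import Optional
    import numpy as np

    def mean_prediction(predictions: Optional[np.ndarray]) -> float:
        if predictions is None:
            raise AssertionError("No predictions found in the dataset")
        # from here on, mypy treats predictions as np.ndarray
        return float(predictions.mean())

    print(mean_prediction(np.array([0.1, 0.5, 0.9])))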

drevalpy/experiment.py

Lines changed: 32 additions & 24 deletions

@@ -4,7 +4,7 @@
 import os
 import shutil
 import warnings
-from typing import Optional
+from typing import Any, Optional

 import numpy as np
 import pandas as pd

@@ -16,7 +16,7 @@
 from .datasets.dataset import DrugResponseDataset, FeatureDataset
 from .evaluation import evaluate, get_mode
 from .models import MODEL_FACTORY, MULTI_DRUG_MODEL_FACTORY, SINGLE_DRUG_MODEL_FACTORY
-from .models.drp_model import DRPModel, SingleDrugModel
+from .models.drp_model import DRPModel
 from .pipeline_function import pipeline_function


@@ -82,14 +82,14 @@ def drug_response_experiment(
     :param test_mode: test mode one of "LPO", "LCO", "LDO" (leave-pair-out, leave-cell-line-out, leave-drug-out)
     :param overwrite: whether to overwrite existing results
     :param path_data: path to the data directory, usually data/
+    :raises ValueError: if no cv splits are found
     """
     if baselines is None:
         baselines = []
     cross_study_datasets = cross_study_datasets or []
     result_path = os.path.join(path_out, run_id, test_mode)
     split_path = os.path.join(result_path, "splits")
     result_folder_exists = os.path.exists(result_path)
-    randomization_test_views = []
     if result_folder_exists and overwrite:
         # if results exists, delete them if overwrite is True
         print(f"Overwriting existing results at {result_path}")

@@ -146,6 +146,9 @@ def drug_response_experiment(

         model_hpam_set = model_class.get_hyperparameter_set()

+        if response_data.cv_splits is None:
+            raise ValueError("No cv splits found.")
+
         for split_index, split in enumerate(response_data.cv_splits):
             print(f"################# FOLD {split_index+1}/{len(response_data.cv_splits)} " f"#################")

@@ -233,7 +236,7 @@ def drug_response_experiment(
                 best_hpams = json.load(f)
             if not is_baseline:
                 if randomization_mode is not None:
-                    print(f"Randomization tests for {model_class.model_name}")
+                    print(f"Randomization tests for {model_class.get_model_name()}")
                     # if this line changes, it also needs to be changed in pipeline:
                     # randomization_split.py
                     randomization_test_views = get_randomization_test_views(

@@ -253,7 +256,7 @@ def drug_response_experiment(
                         response_transformation=response_transformation,
                     )
                 if n_trials_robustness > 0:
-                    print(f"Robustness test for {model_class.model_name}")
+                    print(f"Robustness test for {model_class.get_model_name()}")
                     robustness_test(
                         n_trials=n_trials_robustness,
                         model=model,

@@ -289,7 +292,7 @@ def consolidate_single_drug_model_predictions(
     out_path: str = "",
 ) -> None:
     """
-    Consolidate SingleDrugModel predictions into a single file.
+    Consolidate single drug model predictions into a single file.

     :param models: list of model classes to compare, e.g., [SimpleNeuralNetwork, RandomForest]
     :param n_cv_splits: number of cross-validation splits, e.g., 5

@@ -301,10 +304,11 @@ def consolidate_single_drug_model_predictions(
         will be stored in the work directory.
     """
     for model in models:
-        if model.model_name in SINGLE_DRUG_MODEL_FACTORY:
-            model_instance = MODEL_FACTORY[model.model_name]()
-            model_path = os.path.join(results_path, str(model.model_name))
-            out_path = os.path.join(out_path, str(model.model_name))
+        if model.get_model_name() in SINGLE_DRUG_MODEL_FACTORY:
+
+            model_instance = MODEL_FACTORY[model.get_model_name()]()
+            model_path = os.path.join(results_path, model.get_model_name())
+            out_path = os.path.join(out_path, model.get_model_name())
             os.makedirs(os.path.join(out_path, "predictions"), exist_ok=True)
             if cross_study_datasets:
                 os.makedirs(os.path.join(out_path, "cross_study"), exist_ok=True)

@@ -316,7 +320,7 @@ def consolidate_single_drug_model_predictions(
             for split in range(n_cv_splits):

                 # Collect predictions for drugs across all scenarios (main, cross_study, robustness, randomization)
-                predictions = {
+                predictions: Any = {
                     "main": [],
                     "cross_study": {},
                     "robustness": {},

@@ -423,14 +427,14 @@ def consolidate_single_drug_model_predictions(

 def load_features(
     model: DRPModel, path_data: str, dataset: DrugResponseDataset
-) -> tuple[Optional[FeatureDataset], Optional[FeatureDataset]]:
+) -> tuple[FeatureDataset, Optional[FeatureDataset]]:
     """
     Load and reduce cell line and drug features for a given dataset.

     :param model: model to use, e.g., SimpleNeuralNetwork
     :param path_data: path to the data directory, e.g., data/
     :param dataset: dataset to load features for, e.g., GDSC2
-    :returns: tuple of cell line and drug features
+    :returns: tuple of cell line and, potentially, drug features
     """
     cl_features = model.load_cell_line_features(data_path=path_data, dataset_name=dataset.dataset_name)
     drug_features = model.load_drug_features(data_path=path_data, dataset_name=dataset.dataset_name)

@@ -480,10 +484,11 @@ def cross_study_prediction(

     cell_lines_to_keep = cl_features.identifiers if cl_features is not None else None

+    drugs_to_keep: Optional[np.ndarray] = None
     if single_drug_id is not None:
         drugs_to_keep = np.array([single_drug_id])
-    else:
-        drugs_to_keep = drug_features.identifiers if drug_features is not None else None
+    elif drug_features is not None:
+        drugs_to_keep = drug_features.identifiers

     print(
         f"Reducing cross study dataset ... feature data available for "

@@ -778,12 +783,15 @@ def randomize_train_predict(
         )
         return

-    cl_features_rand = cl_features.copy() if cl_features is not None else None
-    drug_features_rand = drug_features.copy() if drug_features is not None else None
-    if cl_features_rand is not None and view in cl_features.get_view_names():
-        cl_features_rand.randomize_features(view, randomization_type=randomization_type)
-    elif drug_features_rand is not None and view in drug_features.get_view_names():
-        drug_features_rand.randomize_features(view, randomization_type=randomization_type)
+    cl_features_rand: Optional[FeatureDataset] = None
+    if cl_features is not None:
+        cl_features_rand = cl_features.copy()
+        cl_features_rand.randomize_features(view, randomization_type=randomization_type)  # type: ignore[union-attr]
+
+    drug_features_rand: Optional[FeatureDataset] = None
+    if drug_features is not None:
+        drug_features_rand = drug_features.copy()
+        drug_features_rand.randomize_features(view, randomization_type=randomization_type)  # type: ignore[union-attr]

     test_dataset_rand = train_and_predict(
         model=model,

@@ -1069,11 +1077,11 @@ def make_model_list(models: list[type[DRPModel]], response_data: DrugResponseDat
     model_list = {}
     unique_drugs = np.unique(response_data.drug_ids)
     for model in models:
-        if issubclass(model, SingleDrugModel):
+        if model.is_single_drug_model:
             for drug in unique_drugs:
-                model_list[f"{model.model_name}.{drug}"] = str(model.model_name)
+                model_list[f"{model.get_model_name()}.{drug}"] = model.get_model_name()
         else:
-            model_list[str(model.model_name)] = str(model.model_name)
+            model_list[model.get_model_name()] = model.get_model_name()
     return model_list
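
Two recurring moves in this file are replacing issubclass(model, SingleDrugModel) with an is_single_drug_model class attribute and replacing model.model_name with a get_model_name() classmethod, so make_model_list keeps operating on model classes rather than instances. A reduced sketch of that interplay (the base class here is a simplified stand-in for the real DRPModel hierarchy):

    class DRPModel:
        is_single_drug_model: bool = False

        @classmethod
        def get_model_name(cls) -> str:
            raise NotImplementedError

    class MOLIR(DRPModel):
        is_single_drug_model = True

        @classmethod
        def get_model_name(cls) -> str:
            return "MOLIR"

    def make_model_list(models: list[type[DRPModel]], drugs: list[str]) -> dict[str, str]:
        model_list: dict[str, str] = {}
        for model in models:
            if model.is_single_drug_model:
                for drug in drugs:
                    model_list[f"{model.get_model_name()}.{drug}"] = model.get_model_name()
            else:
                model_list[model.get_model_name()] = model.get_model_name()
        return model_list

    print(make_model_list([MOLIR], ["Afatinib"]))  # {'MOLIR.Afatinib': 'MOLIR'}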

drevalpy/models/MOLIR/molir.py

Lines changed: 37 additions & 14 deletions

@@ -13,12 +13,12 @@
 from sklearn.preprocessing import StandardScaler

 from ...datasets.dataset import DrugResponseDataset, FeatureDataset
-from ..drp_model import SingleDrugModel
+from ..drp_model import DRPModel
 from ..utils import get_multiomics_feature_dataset
 from .utils import MOLIModel, get_dimensions_of_omics_data


-class MOLIR(SingleDrugModel):
+class MOLIR(DRPModel):
     """
     Regression extension of MOLI: multi-omics late integration deep neural network.


@@ -28,10 +28,10 @@ class MOLIR(SingleDrugModel):
     We use a regression adaption with MSE loss and a mechanism to find positive and negative samples.
     """

+    is_single_drug_model = True
     cell_line_views = ["gene_expression", "mutations", "copy_number_variation_gistic"]
     drug_views = []
     early_stopping = True
-    model_name = "MOLIR"

     def __init__(self) -> None:
         """

@@ -41,8 +41,17 @@ def __init__(self) -> None:
         gene expression, mutation and copy number variation data.
         """
         super().__init__()
-        self.model = None
-        self.hyperparameters = None
+        self.model: MOLIModel | None = None
+        self.hyperparameters: dict[str, Any] = dict()
+
+    @classmethod
+    def get_model_name(cls) -> str:
+        """
+        Returns the model name.
+
+        :returns: MOLIR
+        """
+        return "MOLIR"

     def build_model(self, hyperparameters: dict[str, Any]) -> None:
         """

@@ -68,6 +77,7 @@ def train(
         copy number variation data. If there is no training data, the model is set to None (and predictions will be
         skipped as well). If there is not enough training data, the predictions will be made on the randomly
         initialized model.
+
         :param output: drug response data
         :param cell_line_input: cell line omics features, i.e., gene expression, mutations and copy number variation
         :param drug_input: drug features, not needed

@@ -86,7 +96,7 @@ def train(
             transformer=scaler_gex,
             view="gene_expression",
         )
-        if self.early_stopping and len(output_earlystopping) < 2:
+        if output_earlystopping is not None and self.early_stopping and len(output_earlystopping) < 2:
             output_earlystopping = None
         dim_gex, dim_mut, dim_cnv = get_dimensions_of_omics_data(cell_line_input)
         self.model = MOLIModel(

@@ -109,19 +119,20 @@ def train(

     def predict(
         self,
-        drug_ids: str | np.ndarray,
-        cell_line_ids: str | np.ndarray,
+        cell_line_ids: np.ndarray,
+        drug_ids: np.ndarray,
+        cell_line_input: FeatureDataset,
         drug_input: FeatureDataset | None = None,
-        cell_line_input: FeatureDataset = None,
     ) -> np.ndarray:
         """
         Predicts the drug response.

         If there was no training data, only nans will be returned.
-        :param drug_ids: Drugs to predict
+
         :param cell_line_ids: Cell lines to predict
-        :param drug_input: drug features, not needed
+        :param drug_ids: Drugs to predict
         :param cell_line_input: cell line omics features
+        :param drug_input: drug features, not needed
         :returns: Predicted drug response
         """
         input_data = self.get_feature_matrices(

@@ -130,9 +141,11 @@ def predict(
             cell_line_input=cell_line_input,
             drug_input=drug_input,
         )
-        gene_expression = input_data["gene_expression"]
-        mutations = input_data["mutations"]
-        cnvs = input_data["copy_number_variation_gistic"]
+        (gene_expression, mutations, cnvs) = (
+            input_data["gene_expression"],
+            input_data["mutations"],
+            input_data["copy_number_variation_gistic"],
+        )
         if self.model is None:
             print("No model trained, will predict NA.")
             return np.array([np.nan] * len(cell_line_ids))

@@ -155,3 +168,13 @@ def load_cell_line_features(self, data_path: str, dataset_name: str) -> FeatureD
         # log transformation
         feature_dataset.apply(function=np.log, view="gene_expression")
         return feature_dataset
+
+    def load_drug_features(self, data_path: str, dataset_name: str) -> FeatureDataset | None:
+        """
+        Returns None, as drug features are not needed for MOLIR.
+
+        :param data_path: path to the data
+        :param dataset_name: name of the dataset
+        :returns: None
+        """
+        return None
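
Annotating self.model as MOLIModel | None makes the untrained path explicit, which is why predict starts with a None check and returns NaNs. A stripped-down sketch of that guard (TinyModel stands in for MOLIModel; the real predict also assembles the omics feature matrices first):

    import numpy as np

    class TinyModel:
        """Stand-in for MOLIModel; always predicts zeros."""

        def predict(self, cell_line_ids: np.ndarray) -> np.ndarray:
            return np.zeros(len(cell_line_ids))

    class MOLIRSketch:
        def __init__(self) -> None:
            self.model: TinyModel | None = None  # mirrors self.model: MOLIModel | None

        def predict(self, cell_line_ids: np.ndarray) -> np.ndarray:
            if self.model is None:
                print("No model trained, will predict NA.")
                return np.array([np.nan] * len(cell_line_ids))
            return self.model.predict(cell_line_ids)

    print(MOLIRSketch().predict(np.array(["CL1", "CL2"])))  # [nan nan]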
