Skip to content

Commit 7e48686

Browse files
Merge pull request #161 from daisybio/fix_dipk_molir_superfeltr
* DIPK fix for batch size 1 * MOLIR/SuperFELTR fix for if self.model is None * MOLIR/SuperFELTR: moved duplicated code to utils * MOLIR/SuperFELTR: instead of VarianceThreshold selection, we now select the top min(1000, n_features) most variable features to avoid issues when there are not enough variable features. * MultiOmicsNN: removed unused self.methylation_features test restructuring: now baseline tests, single drug model tests, neural network tests for non-single drug models * Also adapted tests for the bugs
2 parents d7bb0fc + 6a46cf1 commit 7e48686

File tree

17 files changed

+331
-453
lines changed

17 files changed

+331
-453
lines changed

.github/workflows/build_package.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ jobs:
2222

2323
- name: Install Poetry
2424
run: |
25-
pip install poetry
25+
pipx install poetry
26+
pipx inject poetry poetry-plugin-export
2627
poetry --version
2728
2829
- name: Build package

create_report.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -305,14 +305,16 @@ def draw_per_grouping_algorithm_plots(
305305
t_vs_p=true_vs_pred,
306306
)
307307
"""
308-
For debugging:
308+
#For debugging:
309309
evaluation_results = pd.read_csv(
310310
f'results/{run_id}/evaluation_results.csv', index_col=0
311311
)
312312
evaluation_results_per_drug = pd.read_csv(
313313
f'results/{run_id}/evaluation_results_per_drug.csv', index_col=0
314314
)
315-
evaluation_results_per_cell_line = None
315+
evaluation_results_per_cell_line = pd.read_csv(
316+
f'results/{run_id}/evaluation_results_per_cl.csv', index_col=0
317+
)
316318
true_vs_pred = pd.read_csv(
317319
f'results/{run_id}/true_vs_pred.csv', index_col=0
318320
)

docs/conf.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,9 @@
5656
# the built documents.
5757
#
5858
# The short X.Y version.
59-
version = "1.2.5"
59+
version = "1.2.6"
6060
# The full version, including alpha/beta/rc tags.
61-
release = "1.2.5"
61+
release = "1.2.6"
6262

6363
# The language for content autogenerated by Sphinx. Refer to documentation
6464
# for a list of supported languages.

drevalpy/datasets/dataset.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import numpy as np
2323
import pandas as pd
2424
from sklearn.base import TransformerMixin
25-
from sklearn.feature_selection import VarianceThreshold
2625
from sklearn.model_selection import GroupKFold, train_test_split
2726

2827
from ..pipeline_function import pipeline_function
@@ -1003,9 +1002,6 @@ def fit_transform_features(self, train_ids: np.ndarray, transformer: Transformer
10031002
# Collect all features of the view for fitting the scaler
10041003
train_features = np.vstack([self.features[identifier][view] for identifier in train_ids])
10051004
transformer.fit(train_features)
1006-
if isinstance(transformer, VarianceThreshold):
1007-
mask = transformer.get_support()
1008-
self.meta_info[view] = self.meta_info[view][mask]
10091005

10101006
# Apply transformation and scaling to each feature vector
10111007
for identifier in self.features:

drevalpy/models/DIPK/dipk.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -313,8 +313,10 @@ def predict(
313313
bionic=bionic_features,
314314
molgnet_mask=molgnet_mask,
315315
)
316-
predictions += torch.squeeze(prediction).cpu().tolist()
317-
316+
if prediction.numel() > 1:
317+
predictions += torch.squeeze(prediction).cpu().tolist()
318+
else:
319+
predictions += [prediction.item()]
318320
return np.array(predictions)
319321

320322
def load_cell_line_features(self, data_path: str, dataset_name: str) -> FeatureDataset:

drevalpy/models/MOLIR/molir.py

Lines changed: 13 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,12 @@
99
from typing import Any
1010

1111
import numpy as np
12-
from sklearn.feature_selection import VarianceThreshold
1312
from sklearn.preprocessing import StandardScaler
1413

1514
from ...datasets.dataset import DrugResponseDataset, FeatureDataset
1615
from ..drp_model import DRPModel
1716
from ..utils import get_multiomics_feature_dataset
18-
from .utils import MOLIModel, get_dimensions_of_omics_data
17+
from .utils import MOLIModel, filter_and_sort_omics, get_dimensions_of_omics_data, select_features_for_view
1918

2019

2120
class MOLIR(DRPModel):
@@ -76,8 +75,9 @@ def train(
7675
"""
7776
Initializes and trains the model.
7877
79-
First, the gene expression data is reduced using a variance threshold (0.05) and standardized. Then,
80-
the model is initialized with the hyperparameters and the dimensions of the gene expression, mutation and
78+
First, the gene expression data was reduced using a variance threshold (0.05) and standardized. We chose to use
79+
the most variable 1000 genes instead to avoid issues with the variance threshold.
80+
Then, the model is initialized with the hyperparameters and the dimensions of the gene expression, mutation and
8181
copy number variation data. If there is no training data, the model is set to None (and predictions will be
8282
skipped as well). If there is not enough training data, the predictions will be made on the randomly
8383
initialized model.
@@ -89,11 +89,10 @@ def train(
8989
:param model_checkpoint_dir: directory to save the model checkpoints
9090
"""
9191
if len(output) > 0:
92-
selector_gex = VarianceThreshold(0.05)
93-
cell_line_input.fit_transform_features(
94-
train_ids=np.unique(output.cell_line_ids),
95-
transformer=selector_gex,
92+
cell_line_input = select_features_for_view(
9693
view="gene_expression",
94+
cell_line_input=cell_line_input,
95+
output=output,
9796
)
9897
self.gene_expression_features = cell_line_input.meta_info["gene_expression"]
9998
self.mutations_features = cell_line_input.meta_info["mutations"]
@@ -145,6 +144,9 @@ def predict(
145144
:returns: Predicted drug response
146145
:raises ValueError: If the model was not trained
147146
"""
147+
if self.model is None:
148+
print("No model trained, will predict NA.")
149+
return np.array([np.nan] * len(cell_line_ids))
148150
if (
149151
(self.gene_expression_features is None)
150152
or (self.mutations_features is None)
@@ -164,37 +166,10 @@ def predict(
164166
input_data["copy_number_variation_gistic"],
165167
)
166168

167-
# Filter out features that were not present during training
168-
# This is necessary because the feature order might have changed
169-
# or more features are available
170-
# impute missing features with zeros
171-
for key, features in {
172-
"gene_expression": self.gene_expression_features,
173-
"mutations": self.mutations_features,
174-
"copy_number_variation_gistic": self.copy_number_variation_features,
175-
}.items():
176-
if key == "gene_expression":
177-
values = gene_expression
178-
elif key == "mutations":
179-
values = mutations
180-
else:
181-
values = cnvs
182-
if values.shape[1] != len(features):
183-
new_value = np.zeros((values.shape[0], len(features)))
184-
lookup_table = {feature: i for i, feature in enumerate(cell_line_input.meta_info[key])}
185-
for i, feature in enumerate(features):
186-
if feature in lookup_table:
187-
new_value[:, i] = values[:, lookup_table[feature]]
188-
if key == "gene_expression":
189-
gene_expression = new_value
190-
elif key == "mutations":
191-
mutations = new_value
192-
else:
193-
cnvs = new_value
169+
(gene_expression, mutations, cnv) = filter_and_sort_omics(
170+
model=self, gene_expression=gene_expression, mutations=mutations, cnvs=cnvs, cell_line_input=cell_line_input
171+
)
194172

195-
if self.model is None:
196-
print("No model trained, will predict NA.")
197-
return np.array([np.nan] * len(cell_line_ids))
198173
return self.model.predict(gene_expression, mutations, cnvs)
199174

200175
def load_cell_line_features(self, data_path: str, dataset_name: str) -> FeatureDataset:

drevalpy/models/MOLIR/utils.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from torch.utils.data import DataLoader, Dataset
1919

2020
from drevalpy.datasets.dataset import DrugResponseDataset, FeatureDataset
21+
from drevalpy.models.drp_model import DRPModel
2122

2223

2324
class RegressionDataset(Dataset):
@@ -205,6 +206,76 @@ def get_dimensions_of_omics_data(cell_line_input: FeatureDataset) -> tuple[int,
205206
return dim_gex, dim_mut, dim_cnv
206207

207208

209+
def filter_and_sort_omics(
210+
model: DRPModel, # MOLIR or SuperFELTR
211+
gene_expression: np.ndarray,
212+
mutations: np.ndarray,
213+
cnvs: np.ndarray,
214+
cell_line_input: FeatureDataset,
215+
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
216+
"""
217+
Filters out features that were not present during training and imputes missing features with zeros.
218+
219+
This is necessary because the feature order might have changed or more features are available (cross-study setting).
220+
221+
:param model: either MOLIR or SuperFELTR self
222+
:param gene_expression: new gene expression data from which to predict
223+
:param mutations: new mutation data from which to predict
224+
:param cnvs: new copy number variation data from which to predict
225+
:param cell_line_input: needed for meta information (feature names)
226+
:return: filtered and sorted gene expression, mutations, and copy number variation data
227+
"""
228+
for key, features in {
229+
"gene_expression": model.gene_expression_features, # type: ignore
230+
"mutations": model.mutations_features, # type: ignore
231+
"copy_number_variation_gistic": model.copy_number_variation_features, # type: ignore
232+
}.items():
233+
if key == "gene_expression":
234+
values = gene_expression
235+
elif key == "mutations":
236+
values = mutations
237+
else:
238+
values = cnvs
239+
if values.shape[1] != len(features):
240+
new_value = np.zeros((values.shape[0], len(features)))
241+
lookup_table = {feature: i for i, feature in enumerate(cell_line_input.meta_info[key])}
242+
for i, feature in enumerate(features):
243+
if feature in lookup_table:
244+
new_value[:, i] = values[:, lookup_table[feature]]
245+
if key == "gene_expression":
246+
gene_expression = new_value
247+
elif key == "mutations":
248+
mutations = new_value
249+
else:
250+
cnvs = new_value
251+
return gene_expression, mutations, cnvs
252+
253+
254+
def select_features_for_view(
255+
view: str, # "gene_expression", "mutations", or "copy_number_variation_gistic"
256+
cell_line_input: FeatureDataset,
257+
output: DrugResponseDataset,
258+
):
259+
"""
260+
Selects the top 1000 features with the highest variance for the omics data.
261+
262+
:param view: either "gene_expression", "mutations", or "copy_number_variation_gistic"
263+
:param cell_line_input: the omics data of the cell lines
264+
:param output: the training dataset containing the response output
265+
:return: the modified cell line input with the top 1000 features with the highest variance
266+
"""
267+
train_features = np.vstack(
268+
[cell_line_input.features[identifier][view] for identifier in np.unique(output.cell_line_ids)]
269+
)
270+
variances = np.var(train_features, axis=0)
271+
mask = np.zeros(len(variances), dtype=bool)
272+
mask[np.argsort(variances)[::-1][:1000]] = True
273+
cell_line_input.meta_info[view] = cell_line_input.meta_info[view][mask]
274+
for identifier in cell_line_input.features:
275+
cell_line_input.features[identifier][view] = cell_line_input.features[identifier][view][mask]
276+
return cell_line_input
277+
278+
208279
class MOLIEncoder(nn.Module):
209280
"""
210281
Encoders of the MOLIR model, which is identical to the encoders of the original MOLI model.

drevalpy/models/SimpleNeuralNetwork/multiomics_neural_network.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ def __init__(self):
3434
self.model = None
3535
self.hyperparameters = None
3636
self.pca = None
37-
self.methylation_features = None
3837

3938
@classmethod
4039
def get_model_name(cls) -> str:
@@ -83,7 +82,6 @@ def train(
8382
[cell_line_input.features[id_]["methylation"] for id_ in np.unique(output.cell_line_ids)],
8483
axis=0,
8584
)
86-
self.methylation_features = cell_line_input.meta_info["methylation"]
8785

8886
self.pca.n_components = min(self.pca.n_components, len(unique_methylation))
8987
self.pca = self.pca.fit(unique_methylation)

drevalpy/models/SuperFELTR/hyperparameters.yaml

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -8,29 +8,5 @@ SuperFELTR:
88
out_dim_mutation_encoder: 32
99
out_dim_cnv_encoder: 64
1010
epochs: 30
11-
expression_var_threshold:
12-
GDSC1: 0.1
13-
GDSC2: 0.1
14-
TOYv1: 0.03
15-
TOYv2: 0.03
16-
CCLE: 0.1
17-
CTRPv1: 0.1
18-
CTRPv2: 0.1
19-
mutation_var_threshold:
20-
GDSC1: 0.1
21-
GDSC2: 0.1
22-
TOYv1: 0.05
23-
TOYv2: 0.05
24-
CCLE: 0.1
25-
CTRPv1: 0.1
26-
CTRPv2: 0.1
27-
cnv_var_threshold:
28-
GDSC1: 0.7
29-
GDSC2: 0.7
30-
TOYv1: 0.6
31-
TOYv2: 0.6
32-
CCLE: 0.7
33-
CTRPv1: 0.7
34-
CTRPv2: 0.7
3511
margin: 1.0
3612
learning_rate: 0.01

drevalpy/models/SuperFELTR/superfeltr.py

Lines changed: 15 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,10 @@
1919

2020
import numpy as np
2121
import pytorch_lightning as pl
22-
from sklearn.feature_selection import VarianceThreshold
2322

2423
from ...datasets.dataset import DrugResponseDataset, FeatureDataset
2524
from ..drp_model import DRPModel
26-
from ..MOLIR.utils import get_dimensions_of_omics_data, make_ranges
25+
from ..MOLIR.utils import filter_and_sort_omics, get_dimensions_of_omics_data, make_ranges, select_features_for_view
2726
from ..utils import get_multiomics_feature_dataset
2827
from .utils import SuperFELTEncoder, SuperFELTRegressor, train_superfeltr_model
2928

@@ -201,6 +200,9 @@ def predict(
201200
:returns: predicted drug response
202201
:raises ValueError: if drug_input is not None
203202
"""
203+
if self.expr_encoder is None or self.mut_encoder is None or self.cnv_encoder is None or self.regressor is None:
204+
print("No training data was available, predicting NA")
205+
return np.array([np.nan] * len(cell_line_ids))
204206
if (
205207
self.gene_expression_features is None
206208
or self.mutations_features is None
@@ -223,35 +225,10 @@ def predict(
223225
input_data["copy_number_variation_gistic"],
224226
)
225227

226-
# make cross study prediction possible by selecting only the features that were used during training
227-
# missing features are imputed with zeros
228-
for key, features in {
229-
"gene_expression": self.gene_expression_features,
230-
"mutations": self.mutations_features,
231-
"copy_number_variation_gistic": self.copy_number_variation_features,
232-
}.items():
233-
if key == "gene_expression":
234-
values = gene_expression
235-
elif key == "mutations":
236-
values = mutations
237-
else:
238-
values = cnvs
239-
if values.shape[1] != len(features):
240-
new_value = np.zeros((values.shape[0], len(features)))
241-
lookup_table = {feature: i for i, feature in enumerate(cell_line_input.meta_info[key])}
242-
for i, feature in enumerate(features):
243-
if feature in lookup_table:
244-
new_value[:, i] = values[:, lookup_table[feature]]
245-
if key == "gene_expression":
246-
gene_expression = new_value
247-
elif key == "mutations":
248-
mutations = new_value
249-
else:
250-
cnvs = new_value
228+
(gene_expression, mutations, cnvs) = filter_and_sort_omics(
229+
model=self, gene_expression=gene_expression, mutations=mutations, cnvs=cnvs, cell_line_input=cell_line_input
230+
)
251231

252-
if self.expr_encoder is None or self.mut_encoder is None or self.cnv_encoder is None or self.regressor is None:
253-
print("No training data was available, predicting NA")
254-
return np.array([np.nan] * len(cell_line_ids))
255232
if self.best_checkpoint is None:
256233
print("Not enough training data provided for SuperFELTR Regressor. Predicting with random initialization.")
257234
return self.regressor.predict(gene_expression, mutations, cnvs)
@@ -260,21 +237,20 @@ def predict(
260237

261238
def _feature_selection(self, output: DrugResponseDataset, cell_line_input: FeatureDataset) -> FeatureDataset:
262239
"""
263-
Feature selection for all omics data using the predefined variance thresholds.
240+
Feature selection for all omics data.
241+
242+
Originally, this was done with VarianceThreshold but as data can vary and hence the thresholds are not
243+
universally applicable, we now changed it to select the top 1000 variable features for each omics data.
264244
265245
:param output: training data associated with the response output
266246
:param cell_line_input: cell line omics features
267247
:returns: cell line omics features with selected features
268248
"""
269-
thresholds = {
270-
"gene_expression": self.hyperparameters["expression_var_threshold"][output.dataset_name],
271-
"mutations": self.hyperparameters["mutation_var_threshold"][output.dataset_name],
272-
"copy_number_variation_gistic": self.hyperparameters["cnv_var_threshold"][output.dataset_name],
273-
}
274249
for view in self.cell_line_views:
275-
selector = VarianceThreshold(thresholds[view])
276-
cell_line_input.fit_transform_features(
277-
train_ids=np.unique(output.cell_line_ids), transformer=selector, view=view
250+
cell_line_input = select_features_for_view(
251+
view=view,
252+
cell_line_input=cell_line_input,
253+
output=output,
278254
)
279255
self.gene_expression_features = cell_line_input.meta_info["gene_expression"]
280256
self.mutations_features = cell_line_input.meta_info["mutations"]

0 commit comments

Comments
 (0)