Skip to content

Commit 7770e0f

Browse files
committed
Merge branch 'DIPK_fix_genes_for_cs' of github.com:daisybio/drevalpy into DIPK_fix_genes_for_cs
2 parents 3109aeb + b300851 commit 7770e0f

File tree

18 files changed

+180
-86
lines changed

18 files changed

+180
-86
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@ data/mapping
55
data/GDSC1
66
data/GDSC2
77
data/CCLE
8-
data/Toy_Data
8+
data/TOYv1
9+
data/TOYv2
910
data/CTRPv1
1011
data/CTRPv2
1112

docs/quickstart.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@ Quickstart
33

44
Make sure you have installed DrEvalPy and its dependencies (see `Installation <./installation.html>`_).
55

6-
To make sure the pipeline runs, you can use the fast models NaiveDrugMeanPredictor and NaivePredictor on the Toy_Data
6+
To make sure the pipeline runs, you can use the fast models NaiveDrugMeanPredictor and NaivePredictor on the TOYv1 (subset of CTRPv2) or TOYv2 (subset of GDSC2)
77
dataset with the LPO test mode.
88

99
.. code-block:: bash
1010
11-
python run_suite.py --run_id my_first_run --models NaiveDrugMeanPredictor --baselines NaivePredictor --dataset Toy_Data --test_mode LPO
11+
python run_suite.py --run_id my_first_run --models NaiveDrugMeanPredictor --baselines NaivePredictor --dataset TOYv1 --test_mode LPO
1212
1313
This will train the two baseline models on a subset of gene expression features and drug fingerprint features to
1414
predict IC50 values of the GDSC1 database. It will evaluate in "LPO" which is the leave-pairs-out splitting strategy

docs/usage.rst

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -156,14 +156,21 @@ We provide commonly used datasets to evaluate your model on (GDSC1, GDSC2, CCLE,
156156
+-------------------+-----------------+---------------------+-----------------------------------------------------------------------------------------------------------------------+
157157
| Dataset Name | Number of Drugs | Number of Cell Lines| Description |
158158
+===================+=================+=====================+=======================================================================================================================+
159-
| GDSC1 | 345 | 987 | The Genomics of Drug Sensitivity in Cancer (GDSC) dataset version 1. |
159+
| GDSC1 | 378 | 970 | The Genomics of Drug Sensitivity in Cancer (GDSC) dataset version 1. |
160160
+-------------------+-----------------+---------------------+-----------------------------------------------------------------------------------------------------------------------+
161-
| GDSC2 | 192 | 809 | The Genomics of Drug Sensitivity in Cancer (GDSC) dataset version 2. |
161+
| GDSC2 | 287 | 969 | The Genomics of Drug Sensitivity in Cancer (GDSC) dataset version 2. |
162162
+-------------------+-----------------+---------------------+-----------------------------------------------------------------------------------------------------------------------+
163-
| CCLE | 18 | 471 | The Cancer Cell Line Encyclopedia (CCLE) dataset. The response data will soon be replaced with the data from CTRPv2. |
163+
| CCLE | 24 | 503 | The Cancer Cell Line Encyclopedia (CCLE) dataset. |
164164
+-------------------+-----------------+---------------------+-----------------------------------------------------------------------------------------------------------------------+
165-
| Toy_Data | 40 | 98 | A toy dataset for testing purposes. |
165+
| CTRPv1 | 354 | 243 | The Cancer Therapeutics Response Portal (CTRP) dataset version 1. |
166166
+-------------------+-----------------+---------------------+-----------------------------------------------------------------------------------------------------------------------+
167+
| CTRPv2 | 546 | 886 | The Cancer Therapeutics Response Portal (CTRP) dataset version 2. |
168+
+-------------------+-----------------+---------------------+-----------------------------------------------------------------------------------------------------------------------+
169+
| TOYv1 | 36 | 90 | A toy dataset for testing purposes subsetted from CTRPv2. |
170+
+-------------------+-----------------+---------------------+-----------------------------------------------------------------------------------------------------------------------+
171+
| TOYv2 | 36 | 90 | A second toy dataset for cross study testing purposes. 80 cell lines and 32 drugs overlap TOYv2. |
172+
+-------------------+-----------------+---------------------+-----------------------------------------------------------------------------------------------------------------------+
173+
167174
168175
If using the ``--curve_curator`` option with these datasets, the desired measure provided with the ``--measure`` option is appended with "_curvecurator", e.g. "IC50_curvecurator".
169176
In the provided datasets, these are the measures calculated with the same fitting procedure using CurveCurator. To use the measures reported from the original publications of the

drevalpy/datasets/loader.py

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def load_gdsc1(
2323
2424
:param path_data: Path to the dataset.
2525
:param file_name: File name of the dataset.
26-
:param measure: The name of the column containing the measure to predict, default = "LN_IC50"
26+
:param measure: The name of the column containing the measure to predict, default = "LN_IC50_curvecurator"
2727
2828
:param dataset_name: Name of the dataset.
2929
:return: DrugResponseDataset containing response, cell line IDs, and drug IDs.
@@ -49,7 +49,7 @@ def load_gdsc2(path_data: str = "data", measure: str = "LN_IC50_curvecurator", f
4949
5050
:param path_data: Path to the dataset.
5151
:param file_name: File name of the dataset.
52-
:param measure: The name of the column containing the measure to predict, default = "LN_IC50"
52+
:param measure: The name of the column containing the measure to predict, default = "LN_IC50_curvecurator"
5353
5454
:return: DrugResponseDataset containing response, cell line IDs, and drug IDs.
5555
"""
@@ -64,7 +64,7 @@ def load_ccle(
6464
6565
:param path_data: Path to the dataset.
6666
:param file_name: File name of the dataset.
67-
:param measure: The name of the column containing the measure to predict, default = "LN_IC50"
67+
:param measure: The name of the column containing the measure to predict, default = "LN_IC50_curvecurator"
6868
6969
:return: DrugResponseDataset containing response, cell line IDs, and drug IDs.
7070
"""
@@ -84,17 +84,19 @@ def load_ccle(
8484
)
8585

8686

87-
def load_toy(path_data: str = "data", measure: str = "LN_IC50_curvecurator") -> DrugResponseDataset:
87+
def _load_toy(
88+
path_data: str = "data", measure: str = "LN_IC50_curvecurator", dataset_name="TOYv1"
89+
) -> DrugResponseDataset:
8890
"""
89-
Loads small Toy dataset, subsampled from GDSC1.
91+
Loads small Toy dataset, subsampled from CTRPv2 or GDSC2.
9092
9193
:param path_data: Path to the dataset.
92-
:param measure: The name of the column containing the measure to predict, default = "response"
94+
:param measure: The name of the column containing the measure to predict, default = "LN_IC50_curvecurator"
95+
:param dataset_name: Name of the dataset. Either "TOYv1" or "TOYv2".
9396
9497
:return: DrugResponseDataset containing response, cell line IDs, and drug IDs.
9598
"""
96-
dataset_name = "Toy_Data"
97-
path = os.path.join(path_data, dataset_name, "toy_data.csv")
99+
path = os.path.join(path_data, dataset_name, f"{dataset_name}.csv")
98100
if not os.path.exists(path):
99101
download_dataset(dataset_name, path_data, redownload=True)
100102
response_data = pd.read_csv(path, dtype={"pubchem_id": str})
@@ -107,13 +109,37 @@ def load_toy(path_data: str = "data", measure: str = "LN_IC50_curvecurator") ->
107109
)
108110

109111

112+
def load_toyv1(path_data: str = "data", measure: str = "LN_IC50_curvecurator") -> DrugResponseDataset:
113+
"""
114+
Loads small Toy dataset, subsampled from CTRPv2.
115+
116+
:param path_data: Path to the dataset.
117+
:param measure: The name of the column containing the measure to predict, default = "LN_IC50_curvecurator"
118+
119+
:return: DrugResponseDataset containing response, cell line IDs, and drug IDs.
120+
"""
121+
return _load_toy(path_data, measure, "TOYv1")
122+
123+
124+
def load_toyv2(path_data: str = "data", measure: str = "LN_IC50_curvecurator") -> DrugResponseDataset:
125+
"""
126+
Loads small Toy dataset, subsampled from GDSC2. Can be used to test cross study prediction.
127+
128+
:param path_data: Path to the dataset.
129+
:param measure: The name of the column containing the measure to predict, default = "LN_IC50_curvecurator"
130+
131+
:return: DrugResponseDataset containing response, cell line IDs, and drug IDs.
132+
"""
133+
return _load_toy(path_data, measure, "TOYv2")
134+
135+
110136
def _load_ctrpv(version: str, path_data: str = "data", measure: str = "LN_IC50_curvecurator") -> DrugResponseDataset:
111137
"""
112138
Load CTRPv1 dataset.
113139
114140
:param version: The version of the CTRP dataset to load.
115141
:param path_data: Path to location of CTRPv1 dataset
116-
:param measure: The name of the column containing the measure to predict, default = "response"
142+
:param measure: The name of the column containing the measure to predict, default = "LN_IC50_curvecurator"
117143
118144
:return: DrugResponseDataset containing response, cell line IDs, and drug IDs
119145
"""
@@ -171,7 +197,8 @@ def load_custom(path_data: str | Path, measure: str = "response") -> DrugRespons
171197
"GDSC1": load_gdsc1,
172198
"GDSC2": load_gdsc2,
173199
"CCLE": load_ccle,
174-
"Toy_Data": load_toy,
200+
"TOYv1": load_toyv1,
201+
"TOYv2": load_toyv2,
175202
"CTRPv1": load_ctrpv1,
176203
"CTRPv2": load_ctrpv2,
177204
}
@@ -184,7 +211,7 @@ def load_dataset(
184211
"""
185212
Load a dataset based on the dataset name.
186213
187-
:param dataset_name: The name of the dataset to load. Can be one of ('GDSC1', 'GDSC2', 'CCLE', or 'Toy_Data')
214+
:param dataset_name: The name of the dataset to load. Can be one of ('GDSC1', 'GDSC2', 'CCLE', 'TOYv1', or 'TOYv2')
188215
to download provided datasets, or any other name to allow for custom datasets.
189216
:param path_data: The parent path in which custom or downloaded datasets should be located, or in which raw
190217
viability data is to be found for fitting with CurveCurator (see param curve_curator for details).

drevalpy/datasets/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def download_dataset(
2121
"""
2222
Download the latets dataset from Zenodo.
2323
24-
:param dataset_name: dataset name, e.g., "GDSC1", "GDSC2", "CCLE" or "Toy_Data"
24+
:param dataset_name: dataset name, from "GDSC1", "GDSC2", "CCLE", "CTRPv1", "CTRPv2", "TOYv1", "TOYv2"
2525
:param data_path: where to save the data
2626
:param redownload: whether to redownload the data
2727
:raises HTTPError: if the download fails

drevalpy/models/DIPK/dipk.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,18 @@ def predict(
270270
if not isinstance(self.model, Predictor):
271271
raise ValueError("DIPK model not initialized.")
272272

273+
# Encode gene expression data if this has not been done yet (e.g., for cross-study predictions)
274+
random_cell_line = next(iter(cell_line_input.features.keys()))
275+
if (
276+
len(cell_line_input.features[random_cell_line]["gene_expression"])
277+
!= self.gene_expression_encoder.latent_dim
278+
):
279+
print("Encoding gene expression data for cross study prediction")
280+
cell_line_input.apply(
281+
lambda x: encode_gene_expression(x, self.gene_expression_encoder), # type: ignore[arg-type]
282+
view="gene_expression",
283+
) # type: ignore[arg-type]
284+
273285
# Load data
274286
collate = CollateFn(train=False)
275287
test_samples = get_data(
@@ -314,7 +326,7 @@ def load_cell_line_features(self, data_path: str, dataset_name: str) -> FeatureD
314326
# in the gene expression features of all datasets
315327
gene_expression = load_and_reduce_gene_features(
316328
feature_type="gene_expression",
317-
gene_list="gene_expression_genes_intercept_all_datasets" if dataset_name != "Toy_Data" else None,
329+
gene_list="gene_expression_intersection",
318330
data_path=data_path,
319331
dataset_name=dataset_name,
320332
)

drevalpy/models/DIPK/gene_expression_encoder.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,6 @@
1111
from torch.nn import functional
1212
from torch.utils.data import DataLoader, Dataset
1313

14-
ldim = 512
15-
hdim = [2048, 1024]
16-
1714

1815
class GeneExpressionEncoder(nn.Module):
1916
"""Gene expression encoder.
@@ -22,7 +19,7 @@ class GeneExpressionEncoder(nn.Module):
2219
DIPK model https://github.com/user15632/DIPK.
2320
"""
2421

25-
def __init__(self, input_dim, latent_dim=ldim, h_dims=None, drop_out_rate=0.3):
22+
def __init__(self, input_dim, latent_dim=512, h_dims=None, drop_out_rate=0.3):
2623
"""Initialize the gene expression encoder.
2724
2825
:param input_dim: input dimension
@@ -32,7 +29,7 @@ def __init__(self, input_dim, latent_dim=ldim, h_dims=None, drop_out_rate=0.3):
3229
"""
3330
super().__init__()
3431
if h_dims is None:
35-
h_dims = hdim
32+
h_dims = [2048, 1024]
3633
hidden_dims = deepcopy(h_dims)
3734
hidden_dims.insert(0, input_dim)
3835
modules = []
@@ -47,6 +44,7 @@ def __init__(self, input_dim, latent_dim=ldim, h_dims=None, drop_out_rate=0.3):
4744
)
4845
self.encoder = nn.Sequential(*modules)
4946
self.bottleneck = nn.Linear(hidden_dims[-1], latent_dim)
47+
self.latent_dim = latent_dim
5048

5149
def forward(self, input):
5250
"""Forward pass of the gene expression encoder.
@@ -62,7 +60,7 @@ def forward(self, input):
6260
class GeneExpressionDecoder(nn.Module):
6361
"""Gene expression decoder."""
6462

65-
def __init__(self, input_dim, latent_dim=ldim, h_dims=None, drop_out_rate=0.3):
63+
def __init__(self, input_dim, latent_dim=512, h_dims=None, drop_out_rate=0.3):
6664
"""Initialize the gene expression decoder.
6765
6866
:param input_dim: input dimension
@@ -72,7 +70,7 @@ def __init__(self, input_dim, latent_dim=ldim, h_dims=None, drop_out_rate=0.3):
7270
"""
7371
super().__init__()
7472
if h_dims is None:
75-
h_dims = hdim
73+
h_dims = [2048, 1024]
7674
hidden_dims = deepcopy(h_dims)
7775
hidden_dims.insert(0, input_dim)
7876
self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1])

drevalpy/models/MOLIR/molir.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,15 @@ def predict(
143143
:param cell_line_input: cell line omics features
144144
:param drug_input: drug features, not needed
145145
:returns: Predicted drug response
146+
:raises ValueError: If the model was not trained
146147
"""
148+
if (
149+
(self.gene_expression_features is None)
150+
or (self.mutations_features is None)
151+
or (self.copy_number_variation_features is None)
152+
):
153+
raise ValueError("MOLIR Model not trained, please train the model first.")
154+
147155
input_data = self.get_feature_matrices(
148156
cell_line_ids=cell_line_ids,
149157
drug_ids=drug_ids,
@@ -156,6 +164,10 @@ def predict(
156164
input_data["copy_number_variation_gistic"],
157165
)
158166

167+
# Filter out features that were not present during training
168+
# This is necessary because the feature order might have changed
169+
# or more features are available
170+
# impute missing features with zeros
159171
for key, features in {
160172
"gene_expression": self.gene_expression_features,
161173
"mutations": self.mutations_features,
@@ -199,7 +211,7 @@ def load_cell_line_features(self, data_path: str, dataset_name: str) -> FeatureD
199211
gene_lists={
200212
"gene_expression": "gene_expression_intersection",
201213
"mutations": "mutations_intersection",
202-
"copy_number_variation_gistic": "copy_number_variation_intersection",
214+
"copy_number_variation_gistic": "copy_number_variation_gistic_intersection",
203215
},
204216
omics=self.cell_line_views,
205217
)

drevalpy/models/SuperFELTR/hyperparameters.yaml

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,21 +11,24 @@ SuperFELTR:
1111
expression_var_threshold:
1212
GDSC1: 0.1
1313
GDSC2: 0.1
14-
Toy_Data: 0.03
14+
TOYv1: 0.03
15+
TOYv2: 0.03
1516
CCLE: 0.1
1617
CTRPv1: 0.1
1718
CTRPv2: 0.1
1819
mutation_var_threshold:
1920
GDSC1: 0.1
2021
GDSC2: 0.1
21-
Toy_Data: 0.05
22+
TOYv1: 0.05
23+
TOYv2: 0.05
2224
CCLE: 0.1
2325
CTRPv1: 0.1
2426
CTRPv2: 0.1
2527
cnv_var_threshold:
2628
GDSC1: 0.7
2729
GDSC2: 0.7
28-
Toy_Data: 0.6
30+
TOYv1: 0.6
31+
TOYv2: 0.6
2932
CCLE: 0.7
3033
CTRPv1: 0.7
3134
CTRPv2: 0.7

drevalpy/models/SuperFELTR/superfeltr.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,13 @@ def predict(
201201
:returns: predicted drug response
202202
:raises ValueError: if drug_input is not None
203203
"""
204+
if (
205+
self.gene_expression_features is None
206+
or self.mutations_features is None
207+
or self.copy_number_variation_features is None
208+
):
209+
raise ValueError("Model was not trained, no features available.")
210+
204211
if drug_input is not None:
205212
raise ValueError("SuperFELTR is a single drug model and does not require drug input.")
206213

@@ -216,6 +223,8 @@ def predict(
216223
input_data["copy_number_variation_gistic"],
217224
)
218225

226+
# make cross study prediction possible by selecting only the features that were used during training
227+
# missing features are imputed with zeros
219228
for key, features in {
220229
"gene_expression": self.gene_expression_features,
221230
"mutations": self.mutations_features,

0 commit comments

Comments
 (0)