1919
2020import numpy as np
2121import pytorch_lightning as pl
22- from sklearn .feature_selection import VarianceThreshold
2322
2423from ...datasets .dataset import DrugResponseDataset , FeatureDataset
2524from ..drp_model import DRPModel
26- from ..MOLIR .utils import get_dimensions_of_omics_data , make_ranges
25+ from ..MOLIR .utils import filter_and_sort_omics , get_dimensions_of_omics_data , make_ranges , select_features_for_view
2726from ..utils import get_multiomics_feature_dataset
2827from .utils import SuperFELTEncoder , SuperFELTRegressor , train_superfeltr_model
2928
@@ -201,6 +200,9 @@ def predict(
201200 :returns: predicted drug response
202201 :raises ValueError: if drug_input is not None
203202 """
203+ if self .expr_encoder is None or self .mut_encoder is None or self .cnv_encoder is None or self .regressor is None :
204+ print ("No training data was available, predicting NA" )
205+ return np .array ([np .nan ] * len (cell_line_ids ))
204206 if (
205207 self .gene_expression_features is None
206208 or self .mutations_features is None
@@ -223,35 +225,10 @@ def predict(
223225 input_data ["copy_number_variation_gistic" ],
224226 )
225227
226- # make cross study prediction possible by selecting only the features that were used during training
227- # missing features are imputed with zeros
228- for key , features in {
229- "gene_expression" : self .gene_expression_features ,
230- "mutations" : self .mutations_features ,
231- "copy_number_variation_gistic" : self .copy_number_variation_features ,
232- }.items ():
233- if key == "gene_expression" :
234- values = gene_expression
235- elif key == "mutations" :
236- values = mutations
237- else :
238- values = cnvs
239- if values .shape [1 ] != len (features ):
240- new_value = np .zeros ((values .shape [0 ], len (features )))
241- lookup_table = {feature : i for i , feature in enumerate (cell_line_input .meta_info [key ])}
242- for i , feature in enumerate (features ):
243- if feature in lookup_table :
244- new_value [:, i ] = values [:, lookup_table [feature ]]
245- if key == "gene_expression" :
246- gene_expression = new_value
247- elif key == "mutations" :
248- mutations = new_value
249- else :
250- cnvs = new_value
228+ (gene_expression , mutations , cnvs ) = filter_and_sort_omics (
229+ model = self , gene_expression = gene_expression , mutations = mutations , cnvs = cnvs , cell_line_input = cell_line_input
230+ )
251231
252- if self .expr_encoder is None or self .mut_encoder is None or self .cnv_encoder is None or self .regressor is None :
253- print ("No training data was available, predicting NA" )
254- return np .array ([np .nan ] * len (cell_line_ids ))
255232 if self .best_checkpoint is None :
256233 print ("Not enough training data provided for SuperFELTR Regressor. Predicting with random initialization." )
257234 return self .regressor .predict (gene_expression , mutations , cnvs )
@@ -260,21 +237,20 @@ def predict(
260237
261238 def _feature_selection (self , output : DrugResponseDataset , cell_line_input : FeatureDataset ) -> FeatureDataset :
262239 """
263- Feature selection for all omics data using the predefined variance thresholds.
240+ Feature selection for all omics data.
241+
242+ Originally, this was done with VarianceThreshold, but because data can vary and the thresholds are therefore not
243+ universally applicable, we now select the top 1000 most variable features for each omics data type.
264244
265245 :param output: training data associated with the response output
266246 :param cell_line_input: cell line omics features
267247 :returns: cell line omics features with selected features
268248 """
269- thresholds = {
270- "gene_expression" : self .hyperparameters ["expression_var_threshold" ][output .dataset_name ],
271- "mutations" : self .hyperparameters ["mutation_var_threshold" ][output .dataset_name ],
272- "copy_number_variation_gistic" : self .hyperparameters ["cnv_var_threshold" ][output .dataset_name ],
273- }
274249 for view in self .cell_line_views :
275- selector = VarianceThreshold (thresholds [view ])
276- cell_line_input .fit_transform_features (
277- train_ids = np .unique (output .cell_line_ids ), transformer = selector , view = view
250+ cell_line_input = select_features_for_view (
251+ view = view ,
252+ cell_line_input = cell_line_input ,
253+ output = output ,
278254 )
279255 self .gene_expression_features = cell_line_input .meta_info ["gene_expression" ]
280256 self .mutations_features = cell_line_input .meta_info ["mutations" ]
0 commit comments