@@ -48,7 +48,12 @@ def __init__(self, data_handler, model_retriever, prepared_train_set=None, prepa
4848 logger .error (f"Error initializing RelativitiesCalculator: { e } " )
4949 self .train_set = None
5050 self .test_set = None
51-
51+
52+ def _predict_from_df (self , df ):
53+ preprocessed_data = self .model_retriever .predictor .preprocess (df )
54+ predictions_array = self .model_retriever .predictor ._clf .predict (preprocessed_data [0 ])
55+ return predictions_array
56+
5257 def compute_base_values (self ):
5358 logger .info ("Computing base values on initiation." )
5459 params = self .model_retriever .predictor .params
@@ -102,7 +107,7 @@ def initialize_baseline(self):
102107
103108 def calculate_baseline_prediction (self , sample_train_row ):
104109 logger .info ("Calculating baseline prediction" )
105- return self .model_retriever . predictor . predict (sample_train_row ). iloc [ 0 ] [0 ]
110+ return self ._predict_from_df (sample_train_row )[0 ]
106111
107112 def construct_relativities_df (self ):
108113 logger .info ("constructing relativites DF" )
@@ -129,6 +134,7 @@ def construct_relativities_interaction_df(self):
129134 def get_relativities_df (self ):
130135 """
131136 Computes and returns the relativities DataFrame for the model.
137+ (Optimized with batch prediction)
132138 Returns:
133139 pd.DataFrame: The relativities DataFrame.
134140 """
@@ -139,11 +145,13 @@ def get_relativities_df(self):
139145 self .relativities = {'base' : {'base' : baseline_prediction }}
140146 used_features = self .model_retriever .get_used_features ()
141147
148+ dfs_to_predict = []
149+ features_and_values = [] # To map results back
150+
142151 for feature in used_features :
143152 feature_type = self .model_retriever .features [feature ]['type' ]
144153 base_value = self .base_values [feature ]
145- self .relativities [feature ] = {base_value : 1.0 }
146- train_row_copy = sample_train_row .copy ()
154+ self .relativities [feature ] = {}
147155
148156 exposure_col = self .model_retriever .exposure_columns
149157 exposure_per_modality = self .train_set .groupby (feature )[exposure_col ].sum ()
@@ -153,8 +161,22 @@ def get_relativities_df(self):
153161 values_to_process .append (base_value )
154162
155163 for value in values_to_process :
164+ if value == base_value :
165+ self .relativities [feature ][value ] = 1.0
166+ continue
167+
168+ train_row_copy = sample_train_row .copy ()
156169 train_row_copy [feature ] = value
157- prediction = self .model_retriever .predictor .predict (train_row_copy ).iloc [0 ][0 ]
170+ dfs_to_predict .append (train_row_copy )
171+ features_and_values .append ((feature , value ))
172+
173+ if dfs_to_predict :
174+ logger .info (f"Predicting batch of { len (dfs_to_predict )} rows for relativities..." )
175+ batch_df = pd .concat (dfs_to_predict , ignore_index = True )
176+ batch_predictions = self ._predict_from_df (batch_df )
177+
178+ for i , (feature , value ) in enumerate (features_and_values ):
179+ prediction = batch_predictions [i ]
158180 relativity = prediction / baseline_prediction
159181 self .relativities [feature ][value ] = relativity
160182
@@ -165,6 +187,7 @@ def get_relativities_df(self):
165187 def get_relativities_interactions_df (self ):
166188 """
167189 Computes and returns the relativities DataFrame for the model.
190+ (Optimized with batch prediction)
168191 Returns:
169192 pd.DataFrame: The relativities DataFrame.
170193 """
@@ -174,43 +197,58 @@ def get_relativities_interactions_df(self):
174197
175198 self .relativities_interaction = {}
176199 interactions = self .model_retriever .get_interactions ()
200+
201+ dfs_to_predict = []
202+ features_and_values_list = [] # To map results back
177203
178204 for interaction in interactions :
179205 interaction_first = interaction [0 ]
180206 interaction_second = interaction [1 ]
181207
182208 base_value_first = self .base_values [interaction_first ]
183209 base_value_second = self .base_values [interaction_second ]
184- try :
185- self .relativities_interaction [interaction_first ][interaction_second ] = {base_value_first : {base_value_second : 1.0 }}
186- except KeyError :
187- self .relativities_interaction [interaction_first ] = {interaction_second : {base_value_first : {base_value_second : 1.0 }}}
188- train_row_copy = sample_train_row .copy ()
210+
211+ # Initialize the nested dictionary structure
212+ if interaction_first not in self .relativities_interaction :
213+ self .relativities_interaction [interaction_first ] = {}
214+ if interaction_second not in self .relativities_interaction [interaction_first ]:
215+ self .relativities_interaction [interaction_first ][interaction_second ] = {}
216+ if base_value_first not in self .relativities_interaction [interaction_first ][interaction_second ]:
217+ self .relativities_interaction [interaction_first ][interaction_second ][base_value_first ] = {}
218+
219+ # Set base relativity
220+ self .relativities_interaction [interaction_first ][interaction_second ][base_value_first ][base_value_second ] = 1.0
189221
190222 type_first = self .variable_types .get (interaction_first )
191223 type_second = self .variable_types .get (interaction_second )
192224
193- if type_first == 'CATEGORICAL' :
194- values_to_process_first = self .modalities [interaction_first ]
195- else :
196- values_to_process_first = [base_value_first ]
197-
198- if type_second == 'CATEGORICAL' :
199- values_to_process_second = self .modalities [interaction_second ]
200- else :
201- values_to_process_second = [base_value_second ]
202-
225+ values_to_process_first = self .modalities [interaction_first ] if type_first == 'CATEGORICAL' else [base_value_first ]
226+ values_to_process_second = self .modalities [interaction_second ] if type_second == 'CATEGORICAL' else [base_value_second ]
203227
204228 for value_first in values_to_process_first :
205229 for value_second in values_to_process_second :
230+ if value_first == base_value_first and value_second == base_value_second :
231+ continue # Skip base case, already set to 1.0
232+
233+ train_row_copy = sample_train_row .copy ()
206234 train_row_copy [interaction_first ] = value_first
207235 train_row_copy [interaction_second ] = value_second
208- prediction = self .model_retriever .predictor .predict (train_row_copy ).iloc [0 ][0 ]
209- relativity = prediction / baseline_prediction
210- try :
211- self .relativities_interaction [interaction_first ][interaction_second ][value_first ][value_second ] = relativity
212- except KeyError :
213- self .relativities_interaction [interaction_first ][interaction_second ][value_first ] = {value_second : relativity }
236+ dfs_to_predict .append (train_row_copy )
237+ features_and_values_list .append ((interaction_first , interaction_second , value_first , value_second ))
238+
239+ # Predict on the entire batch at once
240+ if dfs_to_predict :
241+ logger .info (f"Predicting batch of { len (dfs_to_predict )} rows for interactions..." )
242+ batch_df = pd .concat (dfs_to_predict , ignore_index = True )
243+ batch_predictions = self ._predict_from_df (batch_df )
244+
245+ # Map results back
246+ for i , (f1 , f2 , v1 , v2 ) in enumerate (features_and_values_list ):
247+ prediction = batch_predictions [i ]
248+ relativity = prediction / baseline_prediction
249+ if v1 not in self .relativities_interaction [f1 ][f2 ]:
250+ self .relativities_interaction [f1 ][f2 ][v1 ] = {}
251+ self .relativities_interaction [f1 ][f2 ][v1 ][v2 ] = relativity
214252
215253 relativities_interaction_df = self .construct_relativities_interaction_df ()
216254 logger .info ("Relativities DataFrame computed" )
@@ -245,7 +283,7 @@ def prepare_dataset(self, dataset_type='train'):
245283 else :
246284 raise ValueError ("dataset_type must be either 'train' or 'test'" )
247285
248- predicted = self .model_retriever . predictor . predict (dataset )
286+ predicted = self ._predict_from_df (dataset )
249287 dataset ['predicted' ] = predicted
250288 dataset ['weight' ] = 1 if self .model_retriever .exposure_columns is None else dataset [self .model_retriever .exposure_columns ]
251289
@@ -311,12 +349,9 @@ def weighted_mean(x):
311349 if other_feature != feature :
312350 feature_df [other_feature ] = self .base_values [other_feature ]
313351
314- logger .debug ("predictions" )
315- logger .debug (feature_df )
316- predictions = self .model_retriever .predictor .predict (feature_df )
317- logger .debug (predictions )
352+ predictions = self ._predict_from_df (feature_df )
318353 base_data [feature ] = pd .DataFrame ({
319- f'base_{ feature } ' : predictions [ 'prediction' ] ,
354+ f'base_{ feature } ' : predictions ,
320355 feature : feature_df [feature ]
321356 })
322357
0 commit comments