@@ -91,12 +91,13 @@ def _predict_classifier(self, x: np.ndarray, batch_size: int, training_mode: boo
9191 """
9292 raise NotImplementedError
9393
94- def predict (self , x : np .ndarray , batch_size : int = 128 , ** kwargs ) -> np .ndarray :
94+ def predict (self , x : np .ndarray , batch_size : int = 128 , training_mode : bool = False , ** kwargs ) -> np .ndarray :
9595 """
9696 Performs cumulative predictions over every ablation location
9797
9898 :param x: Unablated image
9999 :param batch_size: the batch size for the prediction
100+ :param training_mode: if to run the classifier in training mode
100101 :return: cumulative predictions after sweeping over all the ablation configurations.
101102 """
102103 if self ._channels_first :
@@ -116,20 +117,24 @@ def predict(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> np.ndarray:
116117 for ablation_start in range (ablate_over_range ):
117118 ablated_x = self .ablator .forward (np .copy (x ), column_pos = ablation_start )
118119 if ablation_start == 0 :
119- preds = self ._predict_classifier (ablated_x , batch_size = batch_size , training_mode = False , ** kwargs )
120+ preds = self ._predict_classifier (
121+ ablated_x , batch_size = batch_size , training_mode = training_mode , ** kwargs
122+ )
120123 else :
121- preds += self ._predict_classifier (ablated_x , batch_size = batch_size , training_mode = False , ** kwargs )
124+ preds += self ._predict_classifier (
125+ ablated_x , batch_size = batch_size , training_mode = training_mode , ** kwargs
126+ )
122127 elif self .ablation_type == "block" :
123128 for xcorner in range (rows_in_data ):
124129 for ycorner in range (columns_in_data ):
125130 ablated_x = self .ablator .forward (np .copy (x ), row_pos = xcorner , column_pos = ycorner )
126131 if ycorner == 0 and xcorner == 0 :
127132 preds = self ._predict_classifier (
128- ablated_x , batch_size = batch_size , training_mode = False , ** kwargs
133+ ablated_x , batch_size = batch_size , training_mode = training_mode , ** kwargs
129134 )
130135 else :
131136 preds += self ._predict_classifier (
132- ablated_x , batch_size = batch_size , training_mode = False , ** kwargs
137+ ablated_x , batch_size = batch_size , training_mode = training_mode , ** kwargs
133138 )
134139 return preds
135140
0 commit comments