2424from __future__ import absolute_import , division , print_function , unicode_literals
2525
2626from abc import ABC , abstractmethod
27-
2827from typing import Optional , Union , TYPE_CHECKING
2928import random
3029
@@ -92,7 +91,7 @@ def _predict_classifier(self, x: np.ndarray, batch_size: int, training_mode: boo
9291 """
9392 raise NotImplementedError
9493
95- def predict (self , x : np .ndarray , batch_size : int = 128 , ** kwargs ) -> np .ndarray : # pylint: disable=W0613
94+ def predict (self , x : np .ndarray , batch_size : int = 128 , ** kwargs ) -> np .ndarray :
9695 """
9796 Performs cumulative predictions over every ablation location
9897
@@ -117,17 +116,21 @@ def predict(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> np.ndarray:
117116 for ablation_start in range (ablate_over_range ):
118117 ablated_x = self .ablator .forward (np .copy (x ), column_pos = ablation_start )
119118 if ablation_start == 0 :
120- preds = self ._predict_classifier (ablated_x , batch_size = batch_size , training_mode = False )
119+ preds = self ._predict_classifier (ablated_x , batch_size = batch_size , training_mode = False , ** kwargs )
121120 else :
122- preds += self ._predict_classifier (ablated_x , batch_size = batch_size , training_mode = False )
121+ preds += self ._predict_classifier (ablated_x , batch_size = batch_size , training_mode = False , ** kwargs )
123122 elif self .ablation_type == "block" :
124123 for xcorner in range (rows_in_data ):
125124 for ycorner in range (columns_in_data ):
126125 ablated_x = self .ablator .forward (np .copy (x ), row_pos = xcorner , column_pos = ycorner )
127126 if ycorner == 0 and xcorner == 0 :
128- preds = self ._predict_classifier (ablated_x , batch_size = batch_size , training_mode = False )
127+ preds = self ._predict_classifier (
128+ ablated_x , batch_size = batch_size , training_mode = False , ** kwargs
129+ )
129130 else :
130- preds += self ._predict_classifier (ablated_x , batch_size = batch_size , training_mode = False )
131+ preds += self ._predict_classifier (
132+ ablated_x , batch_size = batch_size , training_mode = False , ** kwargs
133+ )
131134 return preds
132135
133136
0 commit comments