Skip to content

Commit b430251

Browse files
committed
fix kwargs bug
Signed-off-by: GiulioZizzo <[email protected]>
1 parent 1c3ace1 commit b430251

File tree

3 files changed

+20
-9
lines changed

3 files changed

+20
-9
lines changed

art/estimators/certification/derandomized_smoothing/derandomized_smoothing.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -91,12 +91,13 @@ def _predict_classifier(self, x: np.ndarray, batch_size: int, training_mode: boo
9191
"""
9292
raise NotImplementedError
9393

94-
def predict(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> np.ndarray:
94+
def predict(self, x: np.ndarray, batch_size: int = 128, training_mode: bool = False, **kwargs) -> np.ndarray:
9595
"""
9696
Performs cumulative predictions over every ablation location
9797
9898
:param x: Unablated image
9999
:param batch_size: the batch size for the prediction
100+
:param training_mode: if to run the classifier in training mode
100101
:return: cumulative predictions after sweeping over all the ablation configurations.
101102
"""
102103
if self._channels_first:
@@ -116,20 +117,24 @@ def predict(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> np.ndarray:
116117
for ablation_start in range(ablate_over_range):
117118
ablated_x = self.ablator.forward(np.copy(x), column_pos=ablation_start)
118119
if ablation_start == 0:
119-
preds = self._predict_classifier(ablated_x, batch_size=batch_size, training_mode=False, **kwargs)
120+
preds = self._predict_classifier(
121+
ablated_x, batch_size=batch_size, training_mode=training_mode, **kwargs
122+
)
120123
else:
121-
preds += self._predict_classifier(ablated_x, batch_size=batch_size, training_mode=False, **kwargs)
124+
preds += self._predict_classifier(
125+
ablated_x, batch_size=batch_size, training_mode=training_mode, **kwargs
126+
)
122127
elif self.ablation_type == "block":
123128
for xcorner in range(rows_in_data):
124129
for ycorner in range(columns_in_data):
125130
ablated_x = self.ablator.forward(np.copy(x), row_pos=xcorner, column_pos=ycorner)
126131
if ycorner == 0 and xcorner == 0:
127132
preds = self._predict_classifier(
128-
ablated_x, batch_size=batch_size, training_mode=False, **kwargs
133+
ablated_x, batch_size=batch_size, training_mode=training_mode, **kwargs
129134
)
130135
else:
131136
preds += self._predict_classifier(
132-
ablated_x, batch_size=batch_size, training_mode=False, **kwargs
137+
ablated_x, batch_size=batch_size, training_mode=training_mode, **kwargs
133138
)
134139
return preds
135140

art/estimators/certification/derandomized_smoothing/pytorch.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,15 +131,18 @@ def _predict_classifier(self, x: np.ndarray, batch_size: int, training_mode: boo
131131
(torch.nn.functional.softmax(torch.from_numpy(outputs), dim=1) >= self.threshold).type(torch.int)
132132
)
133133

134-
def predict(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> np.ndarray: # type: ignore
134+
def predict(
135+
self, x: np.ndarray, batch_size: int = 128, training_mode: bool = False, **kwargs
136+
) -> np.ndarray: # type: ignore
135137
"""
136138
Perform prediction of the given classifier for a batch of inputs, taking an expectation over transformations.
137139
138140
:param x: Input samples.
139141
:param batch_size: Batch size.
142+
:param training_mode: if to run the classifier in training mode
140143
:return: Array of predictions of shape `(nb_inputs, nb_classes)`.
141144
"""
142-
return DeRandomizedSmoothingMixin.predict(self, x, batch_size=batch_size, training_mode=False, **kwargs)
145+
return DeRandomizedSmoothingMixin.predict(self, x, batch_size=batch_size, training_mode=training_mode, **kwargs)
143146

144147
def _fit_classifier(self, x: np.ndarray, y: np.ndarray, batch_size: int, nb_epochs: int, **kwargs) -> None:
145148
x = x.astype(ART_NUMPY_DTYPE)

art/estimators/certification/derandomized_smoothing/tensorflow.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,12 +170,15 @@ def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: in
170170
if scheduler is not None:
171171
scheduler(epoch)
172172

173-
def predict(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> np.ndarray: # type: ignore
173+
def predict(
174+
self, x: np.ndarray, batch_size: int = 128, training_mode: bool = False, **kwargs
175+
) -> np.ndarray: # type: ignore
174176
"""
175177
Perform prediction of the given classifier for a batch of inputs
176178
177179
:param x: Input samples.
178180
:param batch_size: Batch size.
181+
:param training_mode: if to run the classifier in training mode
179182
:return: Array of predictions of shape `(nb_inputs, nb_classes)`.
180183
"""
181-
return DeRandomizedSmoothingMixin.predict(self, x, batch_size=batch_size, training_mode=False, **kwargs)
184+
return DeRandomizedSmoothingMixin.predict(self, x, batch_size=batch_size, training_mode=training_mode, **kwargs)

0 commit comments

Comments
 (0)