Skip to content

Commit 1c3ace1

Browse files
committed
final review comments
Signed-off-by: GiulioZizzo <[email protected]>
1 parent 86930a6 commit 1c3ace1

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

art/estimators/certification/derandomized_smoothing/derandomized_smoothing.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
from __future__ import absolute_import, division, print_function, unicode_literals
2525

2626
from abc import ABC, abstractmethod
27-
2827
from typing import Optional, Union, TYPE_CHECKING
2928
import random
3029

@@ -92,7 +91,7 @@ def _predict_classifier(self, x: np.ndarray, batch_size: int, training_mode: boo
9291
"""
9392
raise NotImplementedError
9493

95-
def predict(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> np.ndarray: # pylint: disable=W0613
94+
def predict(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> np.ndarray:
9695
"""
9796
Performs cumulative predictions over every ablation location
9897
@@ -117,17 +116,21 @@ def predict(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> np.ndarray:
117116
for ablation_start in range(ablate_over_range):
118117
ablated_x = self.ablator.forward(np.copy(x), column_pos=ablation_start)
119118
if ablation_start == 0:
120-
preds = self._predict_classifier(ablated_x, batch_size=batch_size, training_mode=False)
119+
preds = self._predict_classifier(ablated_x, batch_size=batch_size, training_mode=False, **kwargs)
121120
else:
122-
preds += self._predict_classifier(ablated_x, batch_size=batch_size, training_mode=False)
121+
preds += self._predict_classifier(ablated_x, batch_size=batch_size, training_mode=False, **kwargs)
123122
elif self.ablation_type == "block":
124123
for xcorner in range(rows_in_data):
125124
for ycorner in range(columns_in_data):
126125
ablated_x = self.ablator.forward(np.copy(x), row_pos=xcorner, column_pos=ycorner)
127126
if ycorner == 0 and xcorner == 0:
128-
preds = self._predict_classifier(ablated_x, batch_size=batch_size, training_mode=False)
127+
preds = self._predict_classifier(
128+
ablated_x, batch_size=batch_size, training_mode=False, **kwargs
129+
)
129130
else:
130-
preds += self._predict_classifier(ablated_x, batch_size=batch_size, training_mode=False)
131+
preds += self._predict_classifier(
132+
ablated_x, batch_size=batch_size, training_mode=False, **kwargs
133+
)
131134
return preds
132135

133136

0 commit comments

Comments
 (0)