Skip to content

Commit 6be8850

Browse files
premkiran-o7beat-buesser
authored andcommitted
Updated binom_test to binomtest in randomized_smoothing.py
Signed-off-by: Prem Kiran Laknaboina <[email protected]>
1 parent ea7bc9e commit 6be8850

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

art/estimators/certification/randomized_smoothing/randomized_smoothing.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def predict(self, x: np.ndarray, batch_size: int = 128, verbose: bool = False, *
8585
:type is_abstain: `boolean`
8686
:return: Array of predictions of shape `(nb_inputs, nb_classes)`.
8787
"""
88-
from scipy.stats import binom_test
88+
from scipy.stats import binomtest
8989

9090
is_abstain = kwargs.get("is_abstain")
9191
if is_abstain is not None and not isinstance(is_abstain, bool): # pragma: no cover
@@ -100,12 +100,15 @@ def predict(self, x: np.ndarray, batch_size: int = 128, verbose: bool = False, *
100100
# get class counts
101101
counts_pred = self._prediction_counts(x_i, batch_size=batch_size)
102102
top = counts_pred.argsort()[::-1]
103-
count1 = np.max(counts_pred)
104-
count2 = counts_pred[top[1]]
103+
# Conersion to int
104+
count1 = int(np.max(counts_pred))
105+
count2 = int(counts_pred[top[1]])
105106

106107
# predict or abstain
107108
smooth_prediction = np.zeros(counts_pred.shape)
108-
if (not is_abstain) or (binom_test(count1, count1 + count2, p=0.5) <= self.alpha):
109+
#Get p value from BinomTestResult object
110+
p_value = binomtest(count1, count1 + count2, p=0.5).pvalue
111+
if (not is_abstain) or (p_value <= self.alpha):
109112
smooth_prediction[np.argmax(counts_pred)] = 1
110113
elif is_abstain:
111114
n_abstained += 1

0 commit comments

Comments
 (0)