Skip to content

Commit ba76240

Browse files
Irina NicolaeIrina Nicolae
authored andcommitted
Remove divide by 0 warning in virtual adversarial method
(cherry picked from commit 03d90f7)
1 parent 034b45d commit ba76240

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

art/attacks/virtual_adversarial.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def generate(self, x, **kwargs):
7070
x_adv = np.copy(x)
7171
dims = list(x.shape[1:])
7272
preds = self.classifier.predict(x_adv, logits=False)
73+
tol = 1e-10
7374

7475
for ind, val in enumerate(x_adv):
7576
d = np.random.randn(*dims)
@@ -88,7 +89,7 @@ def generate(self, x, **kwargs):
8889
x[...] += self.finite_diff
8990
preds_new = self.classifier.predict((val + d)[None, ...], logits=False)
9091
kl_div2 = entropy(preds[ind], preds_new[0])
91-
d_new[array_iter.multi_index] = (kl_div2 - kl_div1) / self.finite_diff
92+
d_new[array_iter.multi_index] = (kl_div2 - kl_div1) / (self.finite_diff + tol)
9293
x[...] -= self.finite_diff
9394
d = d_new
9495

@@ -108,11 +109,11 @@ def _normalize(x):
108109
:return: The normalized version of `x`.
109110
:rtype: `np.ndarray`
110111
"""
111-
tol = 1e-12
112+
tol = 1e-10
112113
dims = x.shape
113114

114115
x = x.flatten()
115-
inverse = (np.sum(x**2) + np.sqrt(tol)) ** -.5
116+
inverse = (np.sum(x**2) + tol) ** -.5
116117
x = x * inverse
117118
x = np.reshape(x, dims)
118119

0 commit comments

Comments
 (0)