Skip to content

Commit a89d004

Browse files
committed
fix unittest for cw
1 parent cea5a02 commit a89d004

File tree

2 files changed

+29
-24
lines changed

2 files changed

+29
-24
lines changed

art/attacks/carlini.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -274,10 +274,14 @@ def set_params(self, **kwargs):
274274
# Save attack-specific parameters
275275
super(CarliniL2Method, self).set_params(**kwargs)
276276

277-
if type(self.binary_search_steps) is not int or self.binary_search_steps <= 0:
278-
raise ValueError("The number of binary search steps must be a positive integer.")
277+
if type(self.binary_search_steps) is not int or self.binary_search_steps < 0:
278+
raise ValueError("The number of binary search steps must be a non-negative integer.")
279279

280-
if type(self.max_iter) is not int or self.max_iter <= 0:
281-
raise ValueError("The number of iterations must be a positive integer.")
280+
if type(self.max_iter) is not int or self.max_iter < 0:
281+
raise ValueError("The number of iterations must be a non-negative integer.")
282282

283283
return True
284+
285+
286+
287+

art/attacks/carlini_unittest.py

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,8 @@ def test_failure_attack(self):
8181
learning_rate=2e-2, initial_const=3, decay=1e-2)
8282
params = {'y': random_targets(y_test, tfc.nb_classes)}
8383
x_test_adv = cl2m.generate(x_test, **params)
84-
self.assertTrue((x_test_adv <= 1).all())
85-
self.assertTrue((x_test_adv >= 0).all())
84+
self.assertTrue((x_test_adv <= 1.0001 ).all())
85+
self.assertTrue((x_test_adv >= -0.0001 ).all())
8686
np.testing.assert_almost_equal(x_test, x_test_adv, 3)
8787

8888
def test_tfclassifier(self):
@@ -129,8 +129,9 @@ def test_tfclassifier(self):
129129
params = {'y': random_targets(y_test, tfc.nb_classes)}
130130
x_test_adv = cl2m.generate(x_test, **params)
131131
self.assertFalse((x_test == x_test_adv).all())
132-
self.assertTrue((x_test_adv <= 1).all())
133-
self.assertTrue((x_test_adv >= 0).all())
132+
#print(x_test_adv)
133+
self.assertTrue((x_test_adv <= 1.0001).all())
134+
self.assertTrue((x_test_adv >= -0.0001).all())
134135
target = np.argmax(params['y'], axis=1)
135136
y_pred_adv = np.argmax(tfc.predict(x_test_adv), axis=1)
136137
self.assertTrue((target == y_pred_adv).all())
@@ -141,8 +142,8 @@ def test_tfclassifier(self):
141142
params = {'y': random_targets(y_test, tfc.nb_classes)}
142143
x_test_adv = cl2m.generate(x_test, **params)
143144
self.assertFalse((x_test == x_test_adv).all())
144-
self.assertTrue((x_test_adv <= 1).all())
145-
self.assertTrue((x_test_adv >= 0).all())
145+
self.assertTrue((x_test_adv <= 1.0001).all())
146+
self.assertTrue((x_test_adv >= -0.0001).all())
146147
target = np.argmax(params['y'], axis=1)
147148
y_pred_adv = np.argmax(tfc.predict(x_test_adv), axis=1)
148149
self.assertTrue((target != y_pred_adv).all())
@@ -153,8 +154,8 @@ def test_tfclassifier(self):
153154
params = {}
154155
x_test_adv = cl2m.generate(x_test, **params)
155156
self.assertFalse((x_test == x_test_adv).all())
156-
self.assertTrue((x_test_adv <= 1).all())
157-
self.assertTrue((x_test_adv >= 0).all())
157+
self.assertTrue((x_test_adv <= 1.0001).all())
158+
self.assertTrue((x_test_adv >= -0.0001).all())
158159
y_pred = np.argmax(tfc.predict(x_test), axis=1)
159160
y_pred_adv = np.argmax(tfc.predict(x_test_adv), axis=1)
160161
self.assertTrue((y_pred != y_pred_adv).all())
@@ -194,8 +195,8 @@ def test_krclassifier(self):
194195
params = {'y': random_targets(y_test, krc.nb_classes)}
195196
x_test_adv = cl2m.generate(x_test, **params)
196197
self.assertFalse((x_test == x_test_adv).all())
197-
self.assertTrue((x_test_adv <= 1).all())
198-
self.assertTrue((x_test_adv >= 0).all())
198+
self.assertTrue((x_test_adv <= 1.0001).all())
199+
self.assertTrue((x_test_adv >= -0.0001).all())
199200
target = np.argmax(params['y'], axis=1)
200201
y_pred_adv = np.argmax(krc.predict(x_test_adv), axis=1)
201202
self.assertTrue((target == y_pred_adv).any())
@@ -206,8 +207,8 @@ def test_krclassifier(self):
206207
params = {'y': random_targets(y_test, krc.nb_classes)}
207208
x_test_adv = cl2m.generate(x_test, **params)
208209
self.assertFalse((x_test == x_test_adv).all())
209-
self.assertTrue((x_test_adv <= 1).all())
210-
self.assertTrue((x_test_adv >= 0).all())
210+
self.assertTrue((x_test_adv <= 1.0001).all())
211+
self.assertTrue((x_test_adv >= -0.0001).all())
211212
target = np.argmax(params['y'], axis=1)
212213
y_pred_adv = np.argmax(krc.predict(x_test_adv), axis=1)
213214
self.assertTrue((target != y_pred_adv).all())
@@ -218,8 +219,8 @@ def test_krclassifier(self):
218219
params = {}
219220
x_test_adv = cl2m.generate(x_test, **params)
220221
self.assertFalse((x_test == x_test_adv).all())
221-
self.assertTrue((x_test_adv <= 1).all())
222-
self.assertTrue((x_test_adv >= 0).all())
222+
self.assertTrue((x_test_adv <= 1.0001).all())
223+
self.assertTrue((x_test_adv >= -0.0001).all())
223224
y_pred = np.argmax(krc.predict(x_test), axis=1)
224225
y_pred_adv = np.argmax(krc.predict(x_test_adv), axis=1)
225226
self.assertTrue((y_pred != y_pred_adv).any())
@@ -255,8 +256,8 @@ def test_ptclassifier(self):
255256
params = {'y': random_targets(y_test, ptc.nb_classes)}
256257
x_test_adv = cl2m.generate(x_test, **params)
257258
self.assertFalse((x_test == x_test_adv).all())
258-
self.assertTrue((x_test_adv <= 1).all())
259-
self.assertTrue((x_test_adv >= 0).all())
259+
self.assertTrue((x_test_adv <= 1.0001).all())
260+
self.assertTrue((x_test_adv >= -0.0001).all())
260261
target = np.argmax(params['y'], axis=1)
261262
y_pred_adv = np.argmax(ptc.predict(x_test_adv), axis=1)
262263
self.assertTrue((target == y_pred_adv).any())
@@ -267,8 +268,8 @@ def test_ptclassifier(self):
267268
params = {'y': random_targets(y_test, ptc.nb_classes)}
268269
x_test_adv = cl2m.generate(x_test, **params)
269270
self.assertFalse((x_test == x_test_adv).all())
270-
self.assertTrue((x_test_adv <= 1).all())
271-
self.assertTrue((x_test_adv >= 0).all())
271+
self.assertTrue((x_test_adv <= 1.0001).all())
272+
self.assertTrue((x_test_adv >= -0.0001).all())
272273
target = np.argmax(params['y'], axis=1)
273274
y_pred_adv = np.argmax(ptc.predict(x_test_adv), axis=1)
274275
self.assertTrue((target != y_pred_adv).all())
@@ -279,8 +280,8 @@ def test_ptclassifier(self):
279280
params = {}
280281
x_test_adv = cl2m.generate(x_test, **params)
281282
self.assertFalse((x_test == x_test_adv).all())
282-
self.assertTrue((x_test_adv <= 1).all())
283-
self.assertTrue((x_test_adv >= 0).all())
283+
self.assertTrue((x_test_adv <= 1.0001).all())
284+
self.assertTrue((x_test_adv >= -0.0001).all())
284285
y_pred = np.argmax(ptc.predict(x_test), axis=1)
285286
y_pred_adv = np.argmax(ptc.predict(x_test_adv), axis=1)
286287
self.assertTrue((y_pred != y_pred_adv).any())

0 commit comments

Comments
 (0)