Skip to content

Commit cfc32ed

Browse files
authored
Merge pull request #877 from Trusted-AI/development_issue_870
Define dtype in SquareAttack with norm=2
2 parents 9e2f822 + f31f1a4 commit cfc32ed

File tree

2 files changed

+13
-5
lines changed

2 files changed

+13
-5
lines changed

art/attacks/evasion/square_attack.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ def _get_perturbation(h):
259259

260260
return delta
261261

262-
delta_init = np.zeros(x_robust.shape)
262+
delta_init = np.zeros(x_robust.shape, dtype=ART_NUMPY_DTYPE)
263263

264264
height_start = 0
265265
for _ in range(n_tiles):

tests/attacks/evasion/test_square_attack.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,18 +39,26 @@ def fix_get_mnist_subset(get_mnist_dataset):
3939

4040

4141
@pytest.mark.framework_agnostic
42-
def test_generate(art_warning, fix_get_mnist_subset, image_dl_estimator_for_attack):
42+
@pytest.mark.parametrize("norm", [2, "inf"])
43+
def test_generate(art_warning, fix_get_mnist_subset, image_dl_estimator_for_attack, norm):
4344
try:
4445
classifier = image_dl_estimator_for_attack(SquareAttack)
4546

46-
attack = SquareAttack(estimator=classifier, norm=np.inf, max_iter=5, eps=0.3, p_init=0.8, nb_restarts=1)
47+
attack = SquareAttack(estimator=classifier, norm=norm, max_iter=5, eps=0.3, p_init=0.8, nb_restarts=1)
4748

4849
(x_train_mnist, y_train_mnist, x_test_mnist, y_test_mnist) = fix_get_mnist_subset
4950

5051
x_train_mnist_adv = attack.generate(x=x_train_mnist, y=y_train_mnist)
5152

52-
assert np.mean(np.abs(x_train_mnist_adv - x_train_mnist)) == pytest.approx(0.053533513, abs=0.025)
53-
assert np.max(np.abs(x_train_mnist_adv - x_train_mnist)) == pytest.approx(0.3, abs=0.05)
53+
if norm == "inf":
54+
expected_mean = 0.053533513
55+
expected_max = 0.3
56+
elif norm == 2:
57+
expected_mean = 0.00073682
58+
expected_max = 0.25
59+
60+
assert np.mean(np.abs(x_train_mnist_adv - x_train_mnist)) == pytest.approx(expected_mean, abs=0.025)
61+
assert np.max(np.abs(x_train_mnist_adv - x_train_mnist)) == pytest.approx(expected_max, abs=0.05)
5462
except ARTTestException as e:
5563
art_warning(e)
5664

0 commit comments

Comments
 (0)