Skip to content

Commit 0319e0a

Browse files
committed
Removing fix for spurious YOLO predictions as generates nans due to alpha in python yolo libraries. Adding test for adversarial patch
Signed-off-by: Kieran Fraser <[email protected]>
1 parent b8f7c74 commit 0319e0a

File tree

2 files changed

+50
-3
lines changed

2 files changed

+50
-3
lines changed

art/estimators/object_detection/pytorch_yolo.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -378,8 +378,6 @@ def _get_losses(
378378
:return: Loss gradients of the same shape as `x`.
379379
"""
380380
self._model.train()
381-
self.set_batchnorm(train=False)
382-
self.set_dropout(train=False)
383381

384382
# Apply preprocessing and convert to tensors
385383
x_preprocessed, y_preprocessed = self._preprocess_and_convert_inputs(x=x, y=y, fit=False, no_grad=False)

tests/estimators/object_detection/test_pytorch_yolo.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,7 @@ def test_compute_loss(art_warning, get_pytorch_yolo):
367367
# Compute loss
368368
loss = object_detector.compute_loss(x=x_test, y=y_test)
369369

370-
assert pytest.approx(11.20741, abs=0.9) == float(loss)
370+
assert pytest.approx(11.20741, abs=1.5) == float(loss)
371371

372372
except ARTTestException as e:
373373
art_warning(e)
@@ -386,3 +386,52 @@ def test_pgd(art_warning, get_pytorch_yolo):
386386

387387
except ARTTestException as e:
388388
art_warning(e)
389+
390+
@pytest.mark.only_with_platform("pytorch")
391+
def test_patch(art_warning, get_pytorch_yolo):
392+
try:
393+
394+
from art.attacks.evasion import AdversarialPatchPyTorch
395+
396+
rotation_max=0.0
397+
scale_min=0.1
398+
scale_max=0.3
399+
distortion_scale_max=0.0
400+
learning_rate=1.99
401+
max_iter=2
402+
batch_size=16
403+
patch_shape=(3, 5, 5)
404+
patch_type="circle"
405+
optimizer="pgd"
406+
407+
object_detector, x_test, y_test = get_pytorch_yolo
408+
409+
ap = AdversarialPatchPyTorch(estimator=object_detector, rotation_max=rotation_max,
410+
scale_min=scale_min, scale_max=scale_max, optimizer=optimizer, distortion_scale_max=distortion_scale_max,
411+
learning_rate=learning_rate, max_iter=max_iter, batch_size=batch_size,
412+
patch_shape=patch_shape, patch_type=patch_type, verbose=True, targeted=False)
413+
414+
_, _ = ap.generate(x=x_test, y=y_test)
415+
416+
patched_images = ap.apply_patch(x_test, scale=0.4)
417+
result = object_detector.predict(patched_images)
418+
419+
assert result[0]["scores"].shape == (10647,)
420+
expected_detection_scores = np.asarray(
421+
[
422+
4.3653536e-08,
423+
3.3987994e-06,
424+
2.5681820e-06,
425+
3.9782722e-06,
426+
2.1766680e-05,
427+
2.6138965e-05,
428+
6.3377396e-05,
429+
7.6248516e-06,
430+
4.3447722e-06,
431+
3.6515078e-06,
432+
]
433+
)
434+
np.testing.assert_raises(AssertionError, np.testing.assert_array_almost_equal, result[0]["scores"][:10], expected_detection_scores, 6)
435+
436+
except ARTTestException as e:
437+
art_warning(e)

0 commit comments

Comments
 (0)