Skip to content

Commit 7e46923

Browse files
committed
Formatting fix
Signed-off-by: Kieran Fraser <[email protected]>
1 parent 0319e0a commit 7e46923

File tree

2 files changed

+47
-29
lines changed

2 files changed

+47
-29
lines changed

art/attacks/evasion/adversarial_patch/adversarial_patch_pytorch.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -601,12 +601,14 @@ def __getitem__(self, idx):
601601
target = target.to(self.estimator.device)
602602
else:
603603
targets = []
604-
for idx in range(target['boxes'].shape[0]):
605-
targets.append({
606-
'boxes': target['boxes'][idx].to(self.estimator.device),
607-
'labels': target['labels'][idx].to(self.estimator.device),
608-
'scores': target['scores'][idx].to(self.estimator.device),
609-
})
604+
for idx in range(target["boxes"].shape[0]):
605+
targets.append(
606+
{
607+
"boxes": target["boxes"][idx].to(self.estimator.device),
608+
"labels": target["labels"][idx].to(self.estimator.device),
609+
"scores": target["scores"][idx].to(self.estimator.device),
610+
}
611+
)
610612
_ = self._train_step(images=images, target=targets, mask=None)
611613
else:
612614
for images, target, mask_i in data_loader:
@@ -615,12 +617,14 @@ def __getitem__(self, idx):
615617
target = target.to(self.estimator.device)
616618
else:
617619
targets = []
618-
for idx in range(target['boxes'].shape[0]):
619-
targets.append({
620-
'boxes': target['boxes'][idx].to(self.estimator.device),
621-
'labels': target['labels'][idx].to(self.estimator.device),
622-
'scores': target['scores'][idx].to(self.estimator.device),
623-
})
620+
for idx in range(target["boxes"].shape[0]):
621+
targets.append(
622+
{
623+
"boxes": target["boxes"][idx].to(self.estimator.device),
624+
"labels": target["labels"][idx].to(self.estimator.device),
625+
"scores": target["scores"][idx].to(self.estimator.device),
626+
}
627+
)
624628
mask_i = mask_i.to(self.estimator.device)
625629
_ = self._train_step(images=images, target=targets, mask=mask_i)
626630

tests/estimators/object_detection/test_pytorch_yolo.py

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -387,29 +387,41 @@ def test_pgd(art_warning, get_pytorch_yolo):
387387
except ARTTestException as e:
388388
art_warning(e)
389389

390+
390391
@pytest.mark.only_with_platform("pytorch")
391392
def test_patch(art_warning, get_pytorch_yolo):
392393
try:
393-
394+
394395
from art.attacks.evasion import AdversarialPatchPyTorch
395396

396-
rotation_max=0.0
397-
scale_min=0.1
398-
scale_max=0.3
399-
distortion_scale_max=0.0
400-
learning_rate=1.99
401-
max_iter=2
402-
batch_size=16
403-
patch_shape=(3, 5, 5)
404-
patch_type="circle"
405-
optimizer="pgd"
397+
rotation_max = 0.0
398+
scale_min = 0.1
399+
scale_max = 0.3
400+
distortion_scale_max = 0.0
401+
learning_rate = 1.99
402+
max_iter = 2
403+
batch_size = 16
404+
patch_shape = (3, 5, 5)
405+
patch_type = "circle"
406+
optimizer = "pgd"
406407

407408
object_detector, x_test, y_test = get_pytorch_yolo
408409

409-
ap = AdversarialPatchPyTorch(estimator=object_detector, rotation_max=rotation_max,
410-
scale_min=scale_min, scale_max=scale_max, optimizer=optimizer, distortion_scale_max=distortion_scale_max,
411-
learning_rate=learning_rate, max_iter=max_iter, batch_size=batch_size,
412-
patch_shape=patch_shape, patch_type=patch_type, verbose=True, targeted=False)
410+
ap = AdversarialPatchPyTorch(
411+
estimator=object_detector,
412+
rotation_max=rotation_max,
413+
scale_min=scale_min,
414+
scale_max=scale_max,
415+
optimizer=optimizer,
416+
distortion_scale_max=distortion_scale_max,
417+
learning_rate=learning_rate,
418+
max_iter=max_iter,
419+
batch_size=batch_size,
420+
patch_shape=patch_shape,
421+
patch_type=patch_type,
422+
verbose=True,
423+
targeted=False,
424+
)
413425

414426
_, _ = ap.generate(x=x_test, y=y_test)
415427

@@ -431,7 +443,9 @@ def test_patch(art_warning, get_pytorch_yolo):
431443
3.6515078e-06,
432444
]
433445
)
434-
np.testing.assert_raises(AssertionError, np.testing.assert_array_almost_equal, result[0]["scores"][:10], expected_detection_scores, 6)
446+
np.testing.assert_raises(
447+
AssertionError, np.testing.assert_array_almost_equal, result[0]["scores"][:10], expected_detection_scores, 6
448+
)
435449

436450
except ARTTestException as e:
437-
art_warning(e)
451+
art_warning(e)

0 commit comments

Comments
 (0)