Skip to content

Commit 9a5ad34

Browse files
authored
Merge pull request #2110 from f4str/bad-dets-bug
BadDet Regional Misclassification Attack Bug Fix
2 parents 30e513c + 3607185 commit 9a5ad34

File tree

3 files changed

+352
-353
lines changed

3 files changed

+352
-353
lines changed

art/attacks/poisoning/bad_det/bad_det_rma.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from __future__ import absolute_import, division, print_function, unicode_literals
2424

2525
import logging
26-
from typing import Dict, List, Tuple
26+
from typing import Dict, List, Tuple, Optional
2727

2828
import numpy as np
2929
from tqdm.auto import tqdm
@@ -54,7 +54,7 @@ class BadDetRegionalMisclassificationAttack(PoisoningAttackObjectDetector):
5454
def __init__(
5555
self,
5656
backdoor: PoisoningAttackBackdoor,
57-
class_source: int = 0,
57+
class_source: Optional[int] = None,
5858
class_target: int = 1,
5959
percent_poison: float = 0.3,
6060
channels_first: bool = False,
@@ -64,7 +64,8 @@ def __init__(
6464
Creates a new BadDet Regional Misclassification Attack
6565
6666
:param backdoor: the backdoor chosen for this attack.
67-
:param class_source: The source class from which triggers were selected.
67+
:param class_source: The source class (optionally) from which triggers were selected. If no source is
68+
provided, then all classes will be poisoned.
6869
:param class_target: The target label to which the poisoned model needs to misclassify.
6970
:param percent_poison: The ratio of samples to poison in the source class, with range [0, 1].
7071
:param channels_first: Set channels first or last.
@@ -116,7 +117,7 @@ def poison( # pylint: disable=W0221
116117
target_dict = {k: v.copy() for k, v in y_i.items()}
117118
y_poison.append(target_dict)
118119

119-
if self.class_source in y_i["labels"]:
120+
if self.class_source is None or self.class_source in y_i["labels"]:
120121
source_indices.append(i)
121122

122123
# select indices of samples to poison
@@ -130,7 +131,7 @@ def poison( # pylint: disable=W0221
130131
labels = y_poison[i]["labels"]
131132

132133
for j, (box, label) in enumerate(zip(boxes, labels)):
133-
if label == self.class_source:
134+
if self.class_source is None or label == self.class_source:
134135
# extract the bounding box from the image
135136
x_1, y_1, x_2, y_2 = box.astype(int)
136137
bounding_box = image[y_1:y_2, x_1:x_2, :]

notebooks/poisoning_attack_bad_det.ipynb

Lines changed: 337 additions & 342 deletions
Large diffs are not rendered by default.

tests/attacks/poison/bad_det/test_bad_det_rma.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,16 +32,17 @@
3232

3333

3434
@pytest.mark.framework_agnostic
35+
@pytest.mark.parametrize("class_source", [None, 0])
3536
@pytest.mark.parametrize("percent_poison", [0.3, 1.0])
3637
@pytest.mark.parametrize("channels_first", [True, False])
37-
def test_poison_single_bd(art_warning, image_batch, percent_poison, channels_first):
38+
def test_poison_single_bd(art_warning, image_batch, class_source, percent_poison, channels_first):
3839
x, y = image_batch
3940
backdoor = PoisoningAttackBackdoor(add_single_bd)
4041

4142
try:
4243
attack = BadDetRegionalMisclassificationAttack(
4344
backdoor=backdoor,
44-
class_source=0,
45+
class_source=class_source,
4546
class_target=1,
4647
percent_poison=percent_poison,
4748
channels_first=channels_first,
@@ -56,16 +57,17 @@ def test_poison_single_bd(art_warning, image_batch, percent_poison, channels_fir
5657

5758

5859
@pytest.mark.framework_agnostic
60+
@pytest.mark.parametrize("class_source", [None, 0])
5961
@pytest.mark.parametrize("percent_poison", [0.3, 1.0])
6062
@pytest.mark.parametrize("channels_first", [True, False])
61-
def test_poison_pattern_bd(art_warning, image_batch, percent_poison, channels_first):
63+
def test_poison_pattern_bd(art_warning, image_batch, class_source, percent_poison, channels_first):
6264
x, y = image_batch
6365
backdoor = PoisoningAttackBackdoor(add_pattern_bd)
6466

6567
try:
6668
attack = BadDetRegionalMisclassificationAttack(
6769
backdoor=backdoor,
68-
class_source=0,
70+
class_source=class_source,
6971
class_target=1,
7072
percent_poison=percent_poison,
7173
channels_first=channels_first,
@@ -80,9 +82,10 @@ def test_poison_pattern_bd(art_warning, image_batch, percent_poison, channels_fi
8082

8183

8284
@pytest.mark.framework_agnostic
85+
@pytest.mark.parametrize("class_source", [None, 0])
8386
@pytest.mark.parametrize("percent_poison", [0.3, 1.0])
8487
@pytest.mark.parametrize("channels_first", [True, False])
85-
def test_poison_image(art_warning, image_batch, percent_poison, channels_first):
88+
def test_poison_image(art_warning, image_batch, class_source, percent_poison, channels_first):
8689
x, y = image_batch
8790

8891
file_path = os.path.join(os.getcwd(), "utils/data/backdoors/alert.png")
@@ -95,7 +98,7 @@ def perturbation(x):
9598
try:
9699
attack = BadDetRegionalMisclassificationAttack(
97100
backdoor=backdoor,
98-
class_source=0,
101+
class_source=class_source,
99102
class_target=1,
100103
percent_poison=percent_poison,
101104
channels_first=channels_first,

0 commit comments

Comments
 (0)