Skip to content

Commit 3607185

Browse files
committed
update bad det tests with new RMA option
Signed-off-by: Farhan Ahmed <[email protected]>
1 parent cf06d2c commit 3607185

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

tests/attacks/poison/bad_det/test_bad_det_rma.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,16 +32,17 @@
3232

3333

3434
@pytest.mark.framework_agnostic
35+
@pytest.mark.parametrize("class_source", [None, 0])
3536
@pytest.mark.parametrize("percent_poison", [0.3, 1.0])
3637
@pytest.mark.parametrize("channels_first", [True, False])
37-
def test_poison_single_bd(art_warning, image_batch, percent_poison, channels_first):
38+
def test_poison_single_bd(art_warning, image_batch, class_source, percent_poison, channels_first):
3839
x, y = image_batch
3940
backdoor = PoisoningAttackBackdoor(add_single_bd)
4041

4142
try:
4243
attack = BadDetRegionalMisclassificationAttack(
4344
backdoor=backdoor,
44-
class_source=0,
45+
class_source=class_source,
4546
class_target=1,
4647
percent_poison=percent_poison,
4748
channels_first=channels_first,
@@ -56,16 +57,17 @@ def test_poison_single_bd(art_warning, image_batch, percent_poison, channels_fir
5657

5758

5859
@pytest.mark.framework_agnostic
60+
@pytest.mark.parametrize("class_source", [None, 0])
5961
@pytest.mark.parametrize("percent_poison", [0.3, 1.0])
6062
@pytest.mark.parametrize("channels_first", [True, False])
61-
def test_poison_pattern_bd(art_warning, image_batch, percent_poison, channels_first):
63+
def test_poison_pattern_bd(art_warning, image_batch, class_source, percent_poison, channels_first):
6264
x, y = image_batch
6365
backdoor = PoisoningAttackBackdoor(add_pattern_bd)
6466

6567
try:
6668
attack = BadDetRegionalMisclassificationAttack(
6769
backdoor=backdoor,
68-
class_source=0,
70+
class_source=class_source,
6971
class_target=1,
7072
percent_poison=percent_poison,
7173
channels_first=channels_first,
@@ -80,9 +82,10 @@ def test_poison_pattern_bd(art_warning, image_batch, percent_poison, channels_fi
8082

8183

8284
@pytest.mark.framework_agnostic
85+
@pytest.mark.parametrize("class_source", [None, 0])
8386
@pytest.mark.parametrize("percent_poison", [0.3, 1.0])
8487
@pytest.mark.parametrize("channels_first", [True, False])
85-
def test_poison_image(art_warning, image_batch, percent_poison, channels_first):
88+
def test_poison_image(art_warning, image_batch, class_source, percent_poison, channels_first):
8689
x, y = image_batch
8790

8891
file_path = os.path.join(os.getcwd(), "utils/data/backdoors/alert.png")
@@ -95,7 +98,7 @@ def perturbation(x):
9598
try:
9699
attack = BadDetRegionalMisclassificationAttack(
97100
backdoor=backdoor,
98-
class_source=0,
101+
class_source=class_source,
99102
class_target=1,
100103
percent_poison=percent_poison,
101104
channels_first=channels_first,

0 commit comments

Comments
 (0)