3232
3333
3434@pytest .mark .framework_agnostic
35+ @pytest .mark .parametrize ("class_source" , [None , 0 ])
3536@pytest .mark .parametrize ("percent_poison" , [0.3 , 1.0 ])
3637@pytest .mark .parametrize ("channels_first" , [True , False ])
37- def test_poison_single_bd (art_warning , image_batch , percent_poison , channels_first ):
38+ def test_poison_single_bd (art_warning , image_batch , class_source , percent_poison , channels_first ):
3839 x , y = image_batch
3940 backdoor = PoisoningAttackBackdoor (add_single_bd )
4041
4142 try :
4243 attack = BadDetRegionalMisclassificationAttack (
4344 backdoor = backdoor ,
44- class_source = 0 ,
45+ class_source = class_source ,
4546 class_target = 1 ,
4647 percent_poison = percent_poison ,
4748 channels_first = channels_first ,
@@ -56,16 +57,17 @@ def test_poison_single_bd(art_warning, image_batch, percent_poison, channels_fir
5657
5758
5859@pytest .mark .framework_agnostic
60+ @pytest .mark .parametrize ("class_source" , [None , 0 ])
5961@pytest .mark .parametrize ("percent_poison" , [0.3 , 1.0 ])
6062@pytest .mark .parametrize ("channels_first" , [True , False ])
61- def test_poison_pattern_bd (art_warning , image_batch , percent_poison , channels_first ):
63+ def test_poison_pattern_bd (art_warning , image_batch , class_source , percent_poison , channels_first ):
6264 x , y = image_batch
6365 backdoor = PoisoningAttackBackdoor (add_pattern_bd )
6466
6567 try :
6668 attack = BadDetRegionalMisclassificationAttack (
6769 backdoor = backdoor ,
68- class_source = 0 ,
70+ class_source = class_source ,
6971 class_target = 1 ,
7072 percent_poison = percent_poison ,
7173 channels_first = channels_first ,
@@ -80,9 +82,10 @@ def test_poison_pattern_bd(art_warning, image_batch, percent_poison, channels_fi
8082
8183
8284@pytest .mark .framework_agnostic
85+ @pytest .mark .parametrize ("class_source" , [None , 0 ])
8386@pytest .mark .parametrize ("percent_poison" , [0.3 , 1.0 ])
8487@pytest .mark .parametrize ("channels_first" , [True , False ])
85- def test_poison_image (art_warning , image_batch , percent_poison , channels_first ):
88+ def test_poison_image (art_warning , image_batch , class_source , percent_poison , channels_first ):
8689 x , y = image_batch
8790
8891 file_path = os .path .join (os .getcwd (), "utils/data/backdoors/alert.png" )
@@ -95,7 +98,7 @@ def perturbation(x):
9598 try :
9699 attack = BadDetRegionalMisclassificationAttack (
97100 backdoor = backdoor ,
98- class_source = 0 ,
101+ class_source = class_source ,
99102 class_target = 1 ,
100103 percent_poison = percent_poison ,
101104 channels_first = channels_first ,
0 commit comments