2424from art .attacks .poisoning import HiddenTriggerBackdoor
2525from art .attacks .poisoning import PoisoningAttackBackdoor
2626from art .attacks .poisoning .perturbations import add_pattern_bd
27+ from art .estimators .classification .pytorch import PyTorchClassifier
2728
2829from tests .utils import ARTTestException
2930
@@ -36,12 +37,21 @@ def test_poison(art_warning, get_default_mnist_subset, image_dl_estimator):
3637 (x_train , y_train ), (_ , _ ) = get_default_mnist_subset
3738 classifier , _ = image_dl_estimator (functional = True )
3839
39- def mod (x ):
40- original_dtype = x .dtype
41- x = np .transpose (x , (0 , 2 , 3 , 1 )).astype (np .float32 )
42- x = add_pattern_bd (x )
43- x = np .transpose (x , (0 , 3 , 1 , 2 )).astype (np .float32 )
44- return x .astype (original_dtype )
40+ if isinstance (classifier , PyTorchClassifier ):
41+
42+ def mod (x ):
43+ original_dtype = x .dtype
44+ x = np .transpose (x , (0 , 2 , 3 , 1 )).astype (np .float32 )
45+ x = add_pattern_bd (x )
46+ x = np .transpose (x , (0 , 3 , 1 , 2 )).astype (np .float32 )
47+ return x .astype (original_dtype )
48+
49+ else :
50+
51+ def mod (x ):
52+ original_dtype = x .dtype
53+ x = add_pattern_bd (x )
54+ return x .astype (original_dtype )
4555
4656 backdoor = PoisoningAttackBackdoor (mod )
4757 target = y_train [0 ]
@@ -74,12 +84,21 @@ def test_check_params(art_warning, get_default_mnist_subset, image_dl_estimator)
7484 (x_train , y_train ), (_ , _ ) = get_default_mnist_subset
7585 classifier , _ = image_dl_estimator (functional = True )
7686
77- def mod (x ):
78- original_dtype = x .dtype
79- x = np .transpose (x , (0 , 2 , 3 , 1 )).astype (np .float32 )
80- x = add_pattern_bd (x )
81- x = np .transpose (x , (0 , 3 , 1 , 2 )).astype (np .float32 )
82- return x .astype (original_dtype )
87+ if isinstance (classifier , PyTorchClassifier ):
88+
89+ def mod (x ):
90+ original_dtype = x .dtype
91+ x = np .transpose (x , (0 , 2 , 3 , 1 )).astype (np .float32 )
92+ x = add_pattern_bd (x )
93+ x = np .transpose (x , (0 , 3 , 1 , 2 )).astype (np .float32 )
94+ return x .astype (original_dtype )
95+
96+ else :
97+
98+ def mod (x ):
99+ original_dtype = x .dtype
100+ x = add_pattern_bd (x )
101+ return x .astype (original_dtype )
83102
84103 backdoor = PoisoningAttackBackdoor (mod )
85104 target = y_train [0 ]
0 commit comments