Skip to content

Commit 5d909b3

Browse files
committed
Test fix
Signed-off-by: Kevin Eykholt <[email protected]>
1 parent 1d27225 commit 5d909b3

File tree

1 file changed

+31
-12
lines changed

1 file changed

+31
-12
lines changed

tests/attacks/poison/test_hidden_trigger_backdoor.py

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from art.attacks.poisoning import HiddenTriggerBackdoor
2525
from art.attacks.poisoning import PoisoningAttackBackdoor
2626
from art.attacks.poisoning.perturbations import add_pattern_bd
27+
from art.estimators.classification.pytorch import PyTorchClassifier
2728

2829
from tests.utils import ARTTestException
2930

@@ -36,12 +37,21 @@ def test_poison(art_warning, get_default_mnist_subset, image_dl_estimator):
3637
(x_train, y_train), (_, _) = get_default_mnist_subset
3738
classifier, _ = image_dl_estimator(functional=True)
3839

39-
def mod(x):
40-
original_dtype = x.dtype
41-
x = np.transpose(x, (0, 2, 3, 1)).astype(np.float32)
42-
x = add_pattern_bd(x)
43-
x = np.transpose(x, (0, 3, 1, 2)).astype(np.float32)
44-
return x.astype(original_dtype)
40+
if isinstance(classifier, PyTorchClassifier):
41+
42+
def mod(x):
43+
original_dtype = x.dtype
44+
x = np.transpose(x, (0, 2, 3, 1)).astype(np.float32)
45+
x = add_pattern_bd(x)
46+
x = np.transpose(x, (0, 3, 1, 2)).astype(np.float32)
47+
return x.astype(original_dtype)
48+
49+
else:
50+
51+
def mod(x):
52+
original_dtype = x.dtype
53+
x = add_pattern_bd(x)
54+
return x.astype(original_dtype)
4555

4656
backdoor = PoisoningAttackBackdoor(mod)
4757
target = y_train[0]
@@ -74,12 +84,21 @@ def test_check_params(art_warning, get_default_mnist_subset, image_dl_estimator)
7484
(x_train, y_train), (_, _) = get_default_mnist_subset
7585
classifier, _ = image_dl_estimator(functional=True)
7686

77-
def mod(x):
78-
original_dtype = x.dtype
79-
x = np.transpose(x, (0, 2, 3, 1)).astype(np.float32)
80-
x = add_pattern_bd(x)
81-
x = np.transpose(x, (0, 3, 1, 2)).astype(np.float32)
82-
return x.astype(original_dtype)
87+
if isinstance(classifier, PyTorchClassifier):
88+
89+
def mod(x):
90+
original_dtype = x.dtype
91+
x = np.transpose(x, (0, 2, 3, 1)).astype(np.float32)
92+
x = add_pattern_bd(x)
93+
x = np.transpose(x, (0, 3, 1, 2)).astype(np.float32)
94+
return x.astype(original_dtype)
95+
96+
else:
97+
98+
def mod(x):
99+
original_dtype = x.dtype
100+
x = add_pattern_bd(x)
101+
return x.astype(original_dtype)
83102

84103
backdoor = PoisoningAttackBackdoor(mod)
85104
target = y_train[0]

0 commit comments

Comments
 (0)