Skip to content

Commit 79149bb

Browse files
author
Beat Buesser
committed
Add test to robust DPatch for channels first input
Signed-off-by: Beat Buesser <[email protected]>
1 parent 9f2c9e4 commit 79149bb

File tree

1 file changed

+34
-25
lines changed

1 file changed

+34
-25
lines changed

tests/attacks/evasion/test_dpatch_robust.py

Lines changed: 34 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -107,9 +107,9 @@ def test_generate_targeted(art_warning, fix_get_mnist_subset, fix_get_rcnn, fram
107107
art_warning(e)
108108

109109

110-
@pytest.mark.parametrize("image_format", ["NHWC", "NCHW"])
110+
@pytest.mark.parametrize("channels_first", [False, True])
111111
@pytest.mark.skip_framework("keras", "scikitlearn", "mxnet", "kerastf")
112-
def test_augment_images_with_patch(art_warning, image_format, fix_get_rcnn):
112+
def test_augment_images_with_patch(art_warning, channels_first, fix_get_rcnn):
113113
try:
114114
frcnn = fix_get_rcnn
115115
attack = RobustDPatch(
@@ -126,14 +126,12 @@ def test_augment_images_with_patch(art_warning, image_format, fix_get_rcnn):
126126
verbose=False,
127127
)
128128

129-
if image_format == "NHWC":
130-
patch = np.ones(shape=(4, 4, 1))
131-
x = np.zeros(shape=(1, 10, 10, 1))
132-
channels_first = False
133-
elif image_format == "NCHW":
129+
if channels_first:
134130
patch = np.ones(shape=(1, 4, 4))
135131
x = np.zeros(shape=(1, 1, 10, 10))
136-
channels_first = True
132+
else:
133+
patch = np.ones(shape=(4, 4, 1))
134+
x = np.zeros(shape=(1, 10, 10, 1))
137135

138136
patched_images, _, transformations = attack._augment_images_with_patch(
139137
x=x, y=None, patch=patch, channels_first=channels_first
@@ -143,10 +141,10 @@ def test_augment_images_with_patch(art_warning, image_format, fix_get_rcnn):
143141
patch_sum_expected = 16.0
144142
complement_sum_expected = 0.0
145143

146-
if image_format == "NHWC":
147-
patch_sum = np.sum(patched_images[0, 2:7, 2:7, :])
148-
elif image_format == "NCHW":
144+
if channels_first:
149145
patch_sum = np.sum(patched_images[0, :, 2:7, 2:7])
146+
else:
147+
patch_sum = np.sum(patched_images[0, 2:7, 2:7, :])
150148

151149
complement_sum = np.sum(patched_images[0]) - patch_sum
152150

@@ -196,13 +194,32 @@ def test_apply_patch(art_warning, fix_get_rcnn):
196194
art_warning(e)
197195

198196

197+
@pytest.mark.parametrize("channels_first", [False, True])
199198
@pytest.mark.skip_framework("keras", "scikitlearn", "mxnet", "kerastf")
200-
def test_untransform_gradients(art_warning, fix_get_rcnn):
199+
def test_untransform_gradients(art_warning, fix_get_rcnn, channels_first):
201200
try:
201+
crop_x = 1
202+
crop_y = 1
203+
rot90 = 3
204+
brightness = 1.0
205+
206+
if channels_first:
207+
patch_shape = (1, 4, 4)
208+
gradients = np.zeros(shape=(1, 1, 10, 10))
209+
gradients[:, :, 2:7, 2:7] = 1
210+
gradients = gradients[:, :, crop_x : 10 - crop_x, crop_y : 10 - crop_y]
211+
gradients = np.rot90(gradients, rot90, (2, 3))
212+
else:
213+
patch_shape = (4, 4, 1)
214+
gradients = np.zeros(shape=(1, 10, 10, 1))
215+
gradients[:, 2:7, 2:7, :] = 1
216+
gradients = gradients[:, crop_x : 10 - crop_x, crop_y : 10 - crop_y, :]
217+
gradients = np.rot90(gradients, rot90, (1, 2))
218+
202219
frcnn = fix_get_rcnn
203220
attack = RobustDPatch(
204221
frcnn,
205-
patch_shape=(4, 4, 1),
222+
patch_shape=patch_shape,
206223
patch_location=(2, 2),
207224
crop_range=(0, 0),
208225
brightness_range=(1.0, 1.0),
@@ -214,20 +231,12 @@ def test_untransform_gradients(art_warning, fix_get_rcnn):
214231
verbose=False,
215232
)
216233

217-
gradients = np.zeros(shape=(1, 10, 10, 1))
218-
gradients[:, 2:7, 2:7, :] = 1
219-
220-
crop_x = 1
221-
crop_y = 1
222-
rot90 = 3
223-
brightness = 1.0
224-
225-
gradients = gradients[:, crop_x : 10 - crop_x, crop_y : 10 - crop_y, :]
226-
gradients = np.rot90(gradients, rot90, (1, 2))
227-
228234
transforms = {"crop_x": crop_x, "crop_y": crop_y, "rot90": rot90, "brightness": brightness}
229235

230-
gradients = attack._untransform_gradients(gradients=gradients, transforms=transforms, channels_first=False)
236+
gradients = attack._untransform_gradients(
237+
gradients=gradients, transforms=transforms, channels_first=channels_first
238+
)
239+
231240
gradients_sum = np.sum(gradients[0])
232241
gradients_sum_expected = 16.0
233242

0 commit comments

Comments
 (0)