Skip to content

Commit bf25a9e

Browse files
authored
Merge pull request #1693 from Trusted-AI/development_issue_1692
Update RobustDPatch for channels first images
2 parents c736e39 + 72f78bc commit bf25a9e

File tree

4 files changed

+44
-29
lines changed

4 files changed

+44
-29
lines changed

art/attacks/evasion/dpatch_robust.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,8 @@ def _augment_images_with_patch(
371371
if self.targeted:
372372
predictions = y_copy
373373
else:
374+
if channels_first:
375+
x_copy = np.transpose(x_copy, (0, 3, 1, 2))
374376
predictions = self.estimator.predict(x=x_copy, standardise_output=True)
375377

376378
for i_image in range(x_copy.shape[0]):
@@ -413,8 +415,12 @@ def _untransform_gradients(
413415
# Account for cropping when considering the upper left point of the patch:
414416
x_1 = self.patch_location[0] - int(transforms["crop_x"])
415417
y_1 = self.patch_location[1] - int(transforms["crop_y"])
416-
x_2 = x_1 + self.patch_shape[0]
417-
y_2 = y_1 + self.patch_shape[1]
418+
if channels_first:
419+
x_2 = x_1 + self.patch_shape[1]
420+
y_2 = y_1 + self.patch_shape[2]
421+
else:
422+
x_2 = x_1 + self.patch_shape[0]
423+
y_2 = y_1 + self.patch_shape[1]
418424
gradients = gradients[:, x_1:x_2, y_1:y_2, :]
419425

420426
if channels_first:

art/estimators/object_detection/pytorch_faster_rcnn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def __init__(
7070
maximum values allowed for features. If floats are provided, these will be used as the range of all
7171
features. If arrays are provided, each value will be considered the bound for a feature, thus
7272
the shape of clip values needs to match the total number of features.
73-
:param channels_first: [Currently unused] Set channels first or last.
73+
:param channels_first: Set channels first or last.
7474
:param preprocessing_defences: Preprocessing defence(s) to be applied by the classifier.
7575
:param postprocessing_defences: Postprocessing defence(s) to be applied by the classifier.
7676
:param preprocessing: Tuple of the form `(subtrahend, divisor)` of floats or `np.ndarray` of values to be

art/estimators/object_detection/pytorch_object_detector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def __init__(
7474
maximum values allowed for features. If floats are provided, these will be used as the range of all
7575
features. If arrays are provided, each value will be considered the bound for a feature, thus
7676
the shape of clip values needs to match the total number of features.
77-
:param channels_first: [Currently unused] Set channels first or last.
77+
:param channels_first: Set channels first or last.
7878
:param preprocessing_defences: Preprocessing defence(s) to be applied by the classifier.
7979
:param postprocessing_defences: Postprocessing defence(s) to be applied by the classifier.
8080
:param preprocessing: Tuple of the form `(subtrahend, divisor)` of floats or `np.ndarray` of values to be

tests/attacks/evasion/test_dpatch_robust.py

Lines changed: 34 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -107,9 +107,9 @@ def test_generate_targeted(art_warning, fix_get_mnist_subset, fix_get_rcnn, fram
107107
art_warning(e)
108108

109109

110-
@pytest.mark.parametrize("image_format", ["NHWC", "NCHW"])
110+
@pytest.mark.parametrize("channels_first", [False, True])
111111
@pytest.mark.skip_framework("keras", "scikitlearn", "mxnet", "kerastf")
112-
def test_augment_images_with_patch(art_warning, image_format, fix_get_rcnn):
112+
def test_augment_images_with_patch(art_warning, channels_first, fix_get_rcnn):
113113
try:
114114
frcnn = fix_get_rcnn
115115
attack = RobustDPatch(
@@ -126,14 +126,12 @@ def test_augment_images_with_patch(art_warning, image_format, fix_get_rcnn):
126126
verbose=False,
127127
)
128128

129-
if image_format == "NHWC":
130-
patch = np.ones(shape=(4, 4, 1))
131-
x = np.zeros(shape=(1, 10, 10, 1))
132-
channels_first = False
133-
elif image_format == "NCHW":
129+
if channels_first:
134130
patch = np.ones(shape=(1, 4, 4))
135131
x = np.zeros(shape=(1, 1, 10, 10))
136-
channels_first = True
132+
else:
133+
patch = np.ones(shape=(4, 4, 1))
134+
x = np.zeros(shape=(1, 10, 10, 1))
137135

138136
patched_images, _, transformations = attack._augment_images_with_patch(
139137
x=x, y=None, patch=patch, channels_first=channels_first
@@ -143,10 +141,10 @@ def test_augment_images_with_patch(art_warning, image_format, fix_get_rcnn):
143141
patch_sum_expected = 16.0
144142
complement_sum_expected = 0.0
145143

146-
if image_format == "NHWC":
147-
patch_sum = np.sum(patched_images[0, 2:7, 2:7, :])
148-
elif image_format == "NCHW":
144+
if channels_first:
149145
patch_sum = np.sum(patched_images[0, :, 2:7, 2:7])
146+
else:
147+
patch_sum = np.sum(patched_images[0, 2:7, 2:7, :])
150148

151149
complement_sum = np.sum(patched_images[0]) - patch_sum
152150

@@ -196,13 +194,32 @@ def test_apply_patch(art_warning, fix_get_rcnn):
196194
art_warning(e)
197195

198196

197+
@pytest.mark.parametrize("channels_first", [False, True])
199198
@pytest.mark.skip_framework("keras", "scikitlearn", "mxnet", "kerastf")
200-
def test_untransform_gradients(art_warning, fix_get_rcnn):
199+
def test_untransform_gradients(art_warning, fix_get_rcnn, channels_first):
201200
try:
201+
crop_x = 1
202+
crop_y = 1
203+
rot90 = 3
204+
brightness = 1.0
205+
206+
if channels_first:
207+
patch_shape = (1, 4, 4)
208+
gradients = np.zeros(shape=(1, 1, 10, 10))
209+
gradients[:, :, 2:7, 2:7] = 1
210+
gradients = gradients[:, :, crop_x : 10 - crop_x, crop_y : 10 - crop_y]
211+
gradients = np.rot90(gradients, rot90, (2, 3))
212+
else:
213+
patch_shape = (4, 4, 1)
214+
gradients = np.zeros(shape=(1, 10, 10, 1))
215+
gradients[:, 2:7, 2:7, :] = 1
216+
gradients = gradients[:, crop_x : 10 - crop_x, crop_y : 10 - crop_y, :]
217+
gradients = np.rot90(gradients, rot90, (1, 2))
218+
202219
frcnn = fix_get_rcnn
203220
attack = RobustDPatch(
204221
frcnn,
205-
patch_shape=(4, 4, 1),
222+
patch_shape=patch_shape,
206223
patch_location=(2, 2),
207224
crop_range=(0, 0),
208225
brightness_range=(1.0, 1.0),
@@ -214,20 +231,12 @@ def test_untransform_gradients(art_warning, fix_get_rcnn):
214231
verbose=False,
215232
)
216233

217-
gradients = np.zeros(shape=(1, 10, 10, 1))
218-
gradients[:, 2:7, 2:7, :] = 1
219-
220-
crop_x = 1
221-
crop_y = 1
222-
rot90 = 3
223-
brightness = 1.0
224-
225-
gradients = gradients[:, crop_x : 10 - crop_x, crop_y : 10 - crop_y, :]
226-
gradients = np.rot90(gradients, rot90, (1, 2))
227-
228234
transforms = {"crop_x": crop_x, "crop_y": crop_y, "rot90": rot90, "brightness": brightness}
229235

230-
gradients = attack._untransform_gradients(gradients=gradients, transforms=transforms, channels_first=False)
236+
gradients = attack._untransform_gradients(
237+
gradients=gradients, transforms=transforms, channels_first=channels_first
238+
)
239+
231240
gradients_sum = np.sum(gradients[0])
232241
gradients_sum_expected = 16.0
233242

0 commit comments

Comments
 (0)