@@ -107,9 +107,9 @@ def test_generate_targeted(art_warning, fix_get_mnist_subset, fix_get_rcnn, fram
107107 art_warning (e )
108108
109109
110- @pytest .mark .parametrize ("image_format " , ["NHWC" , "NCHW" ])
110+ @pytest .mark .parametrize ("channels_first " , [False , True ])
111111@pytest .mark .skip_framework ("keras" , "scikitlearn" , "mxnet" , "kerastf" )
112- def test_augment_images_with_patch (art_warning , image_format , fix_get_rcnn ):
112+ def test_augment_images_with_patch (art_warning , channels_first , fix_get_rcnn ):
113113 try :
114114 frcnn = fix_get_rcnn
115115 attack = RobustDPatch (
@@ -126,14 +126,12 @@ def test_augment_images_with_patch(art_warning, image_format, fix_get_rcnn):
126126 verbose = False ,
127127 )
128128
129- if image_format == "NHWC" :
130- patch = np .ones (shape = (4 , 4 , 1 ))
131- x = np .zeros (shape = (1 , 10 , 10 , 1 ))
132- channels_first = False
133- elif image_format == "NCHW" :
129+ if channels_first :
134130 patch = np .ones (shape = (1 , 4 , 4 ))
135131 x = np .zeros (shape = (1 , 1 , 10 , 10 ))
136- channels_first = True
132+ else :
133+ patch = np .ones (shape = (4 , 4 , 1 ))
134+ x = np .zeros (shape = (1 , 10 , 10 , 1 ))
137135
138136 patched_images , _ , transformations = attack ._augment_images_with_patch (
139137 x = x , y = None , patch = patch , channels_first = channels_first
@@ -143,10 +141,10 @@ def test_augment_images_with_patch(art_warning, image_format, fix_get_rcnn):
143141 patch_sum_expected = 16.0
144142 complement_sum_expected = 0.0
145143
146- if image_format == "NHWC" :
147- patch_sum = np .sum (patched_images [0 , 2 :7 , 2 :7 , :])
148- elif image_format == "NCHW" :
144+ if channels_first :
149145 patch_sum = np .sum (patched_images [0 , :, 2 :7 , 2 :7 ])
146+ else :
147+ patch_sum = np .sum (patched_images [0 , 2 :7 , 2 :7 , :])
150148
151149 complement_sum = np .sum (patched_images [0 ]) - patch_sum
152150
@@ -196,13 +194,32 @@ def test_apply_patch(art_warning, fix_get_rcnn):
196194 art_warning (e )
197195
198196
197+ @pytest .mark .parametrize ("channels_first" , [False , True ])
199198@pytest .mark .skip_framework ("keras" , "scikitlearn" , "mxnet" , "kerastf" )
200- def test_untransform_gradients (art_warning , fix_get_rcnn ):
199+ def test_untransform_gradients (art_warning , fix_get_rcnn , channels_first ):
201200 try :
201+ crop_x = 1
202+ crop_y = 1
203+ rot90 = 3
204+ brightness = 1.0
205+
206+ if channels_first :
207+ patch_shape = (1 , 4 , 4 )
208+ gradients = np .zeros (shape = (1 , 1 , 10 , 10 ))
209+ gradients [:, :, 2 :7 , 2 :7 ] = 1
210+ gradients = gradients [:, :, crop_x : 10 - crop_x , crop_y : 10 - crop_y ]
211+ gradients = np .rot90 (gradients , rot90 , (2 , 3 ))
212+ else :
213+ patch_shape = (4 , 4 , 1 )
214+ gradients = np .zeros (shape = (1 , 10 , 10 , 1 ))
215+ gradients [:, 2 :7 , 2 :7 , :] = 1
216+ gradients = gradients [:, crop_x : 10 - crop_x , crop_y : 10 - crop_y , :]
217+ gradients = np .rot90 (gradients , rot90 , (1 , 2 ))
218+
202219 frcnn = fix_get_rcnn
203220 attack = RobustDPatch (
204221 frcnn ,
205- patch_shape = ( 4 , 4 , 1 ) ,
222+ patch_shape = patch_shape ,
206223 patch_location = (2 , 2 ),
207224 crop_range = (0 , 0 ),
208225 brightness_range = (1.0 , 1.0 ),
@@ -214,20 +231,12 @@ def test_untransform_gradients(art_warning, fix_get_rcnn):
214231 verbose = False ,
215232 )
216233
217- gradients = np .zeros (shape = (1 , 10 , 10 , 1 ))
218- gradients [:, 2 :7 , 2 :7 , :] = 1
219-
220- crop_x = 1
221- crop_y = 1
222- rot90 = 3
223- brightness = 1.0
224-
225- gradients = gradients [:, crop_x : 10 - crop_x , crop_y : 10 - crop_y , :]
226- gradients = np .rot90 (gradients , rot90 , (1 , 2 ))
227-
228234 transforms = {"crop_x" : crop_x , "crop_y" : crop_y , "rot90" : rot90 , "brightness" : brightness }
229235
230- gradients = attack ._untransform_gradients (gradients = gradients , transforms = transforms , channels_first = False )
236+ gradients = attack ._untransform_gradients (
237+ gradients = gradients , transforms = transforms , channels_first = channels_first
238+ )
239+
231240 gradients_sum = np .sum (gradients [0 ])
232241 gradients_sum_expected = 16.0
233242
0 commit comments