@@ -117,9 +117,8 @@ def generate( # pylint: disable=W0221
117117 mask = kwargs .get ("mask" )
118118 if mask is not None :
119119 mask = mask .copy ()
120- if (
121- mask is not None
122- and (mask .dtype != np .bool )
120+ if mask is not None and (
121+ mask .dtype != np .bool
123122 or not (mask .shape [0 ] == 1 or mask .shape [0 ] == x .shape [0 ])
124123 or not (
125124 (mask .shape [1 ] == x .shape [1 ] and mask .shape [2 ] == x .shape [2 ])
@@ -151,7 +150,12 @@ def generate( # pylint: disable=W0221
151150 self .target_label = target_label
152151
153152 patched_images , transforms = self ._augment_images_with_patch (
154- x , self ._patch , random_location = True , channels_first = self .estimator .channels_first , mask = mask
153+ x ,
154+ self ._patch ,
155+ random_location = True ,
156+ channels_first = self .estimator .channels_first ,
157+ mask = mask ,
158+ transforms = None ,
155159 )
156160 patch_target : List [Dict [str , np .ndarray ]] = list ()
157161
@@ -237,6 +241,15 @@ def generate( # pylint: disable=W0221
237241 a_max = self .estimator .clip_values [1 ],
238242 )
239243
244+ patched_images , _ = self ._augment_images_with_patch (
245+ x ,
246+ self ._patch ,
247+ random_location = False ,
248+ channels_first = self .estimator .channels_first ,
249+ mask = None ,
250+ transforms = transforms ,
251+ )
252+
240253 return self ._patch
241254
242255 @staticmethod
@@ -246,6 +259,7 @@ def _augment_images_with_patch(
246259 random_location : bool ,
247260 channels_first : bool ,
248261 mask : Optional [np .ndarray ] = None ,
262+ transforms : List [Dict [str , int ]] = None ,
249263 ) -> Tuple [np .ndarray , List [Dict [str , int ]]]:
250264 """
251265 Augment images with patch.
@@ -258,9 +272,16 @@ def _augment_images_with_patch(
258272 :param mask: An boolean array of shape equal to the shape of a single samples (1, H, W) or the shape of `x`
259273 (N, H, W) without their channel dimensions. Any features for which the mask is True can be the
260274 center location of the patch during sampling.
275+ :param transforms: Patch transforms, requires `random_location=False`, and `mask=None`.
261276 :type mask: `np.ndarray`
262277 """
263- transformations = list ()
278+ if transforms is not None :
279+ if random_location or mask is not None :
280+ raise ValueError (
281+ "Definition of patch locations in `locations` requires `random_location=False`, and `mask=None`."
282+ )
283+
284+ random_transformations = list ()
264285 x_copy = x .copy ()
265286 patch_copy = patch .copy ()
266287
@@ -270,48 +291,56 @@ def _augment_images_with_patch(
270291
271292 for i_image in range (x .shape [0 ]):
272293
273- if random_location :
274- if mask is None :
275- i_x_1 = random .randint (0 , x_copy .shape [1 ] - 1 - patch_copy .shape [0 ])
276- i_y_1 = random .randint (0 , x_copy .shape [2 ] - 1 - patch_copy .shape [1 ])
277- else :
294+ if transforms is None :
278295
279- if mask .shape [0 ] == 1 :
280- mask_2d = mask [0 , :, :]
296+ if random_location :
297+ if mask is None :
298+ i_x_1 = random .randint (0 , x_copy .shape [1 ] - 1 - patch_copy .shape [0 ])
299+ i_y_1 = random .randint (0 , x_copy .shape [2 ] - 1 - patch_copy .shape [1 ])
281300 else :
282- mask_2d = mask [i_image , :, :]
283301
284- edge_x_0 = patch_copy .shape [0 ] // 2
285- edge_x_1 = patch_copy . shape [ 0 ] - edge_x_0
286- edge_y_0 = patch_copy . shape [ 1 ] // 2
287- edge_y_1 = patch_copy . shape [ 1 ] - edge_y_0
302+ if mask .shape [0 ] == 1 :
303+ mask_2d = mask [ 0 , :, :]
304+ else :
305+ mask_2d = mask [ i_image , :, :]
288306
289- mask_2d [ 0 : edge_x_0 , :] = False
290- mask_2d [ - edge_x_1 :, :] = False
291- mask_2d [:, 0 : edge_y_0 ] = False
292- mask_2d [:, - edge_y_1 :] = False
307+ edge_x_0 = patch_copy . shape [ 0 ] // 2
308+ edge_x_1 = patch_copy . shape [ 0 ] - edge_x_0
309+ edge_y_0 = patch_copy . shape [ 1 ] // 2
310+ edge_y_1 = patch_copy . shape [ 1 ] - edge_y_0
293311
294- num_pos = np .argwhere (mask_2d ).shape [0 ]
295- pos_id = np .random .choice (num_pos , size = 1 )
296- pos = np .argwhere (mask_2d > 0 )[pos_id [0 ]]
297- i_x_1 = pos [0 ] - edge_x_0
298- i_y_1 = pos [1 ] - edge_y_0
312+ mask_2d [0 :edge_x_0 , :] = False
313+ mask_2d [- edge_x_1 :, :] = False
314+ mask_2d [:, 0 :edge_y_0 ] = False
315+ mask_2d [:, - edge_y_1 :] = False
299316
300- else :
301- i_x_1 = 0
302- i_y_1 = 0
317+ num_pos = np .argwhere (mask_2d ).shape [0 ]
318+ pos_id = np .random .choice (num_pos , size = 1 )
319+ pos = np .argwhere (mask_2d > 0 )[pos_id [0 ]]
320+ i_x_1 = pos [0 ] - edge_x_0
321+ i_y_1 = pos [1 ] - edge_y_0
303322
304- i_x_2 = i_x_1 + patch_copy .shape [0 ]
305- i_y_2 = i_y_1 + patch_copy .shape [1 ]
323+ else :
324+ i_x_1 = 0
325+ i_y_1 = 0
326+
327+ i_x_2 = i_x_1 + patch_copy .shape [0 ]
328+ i_y_2 = i_y_1 + patch_copy .shape [1 ]
306329
307- transformations .append ({"i_x_1" : i_x_1 , "i_y_1" : i_y_1 , "i_x_2" : i_x_2 , "i_y_2" : i_y_2 })
330+ random_transformations .append ({"i_x_1" : i_x_1 , "i_y_1" : i_y_1 , "i_x_2" : i_x_2 , "i_y_2" : i_y_2 })
331+
332+ else :
333+ i_x_1 = transforms [i_image ]["i_x_1" ]
334+ i_x_2 = transforms [i_image ]["i_x_2" ]
335+ i_y_1 = transforms [i_image ]["i_y_1" ]
336+ i_y_2 = transforms [i_image ]["i_y_2" ]
308337
309338 x_copy [i_image , i_x_1 :i_x_2 , i_y_1 :i_y_2 , :] = patch_copy
310339
311340 if channels_first :
312341 x_copy = np .transpose (x_copy , (0 , 3 , 1 , 2 ))
313342
314- return x_copy , transformations
343+ return x_copy , random_transformations
315344
316345 def apply_patch (
317346 self ,
0 commit comments