@@ -99,15 +99,33 @@ def __init__(
9999 self .learning_rate = learning_rate
100100 self .max_iter = max_iter
101101 self .batch_size = batch_size
102- self .patch_shape = self .estimator .input_shape
103102 self .clip_patch = clip_patch
104103 self ._check_params ()
105104
105+ self .image_shape = self .estimator .input_shape
106+
107+ if self .estimator .channels_first :
108+ self .i_h = 1
109+ self .i_w = 2
110+ else :
111+ self .i_h = 0
112+ self .i_w = 1
113+
114+ if self .estimator .channels_first :
115+ smallest_image_edge = np .minimum (self .image_shape [1 ], self .image_shape [2 ])
116+ nb_channels = self .image_shape [0 ]
117+ self .patch_shape = (nb_channels , smallest_image_edge , smallest_image_edge )
118+ else :
119+ smallest_image_edge = np .minimum (self .image_shape [0 ], self .image_shape [1 ])
120+ nb_channels = self .image_shape [2 ]
121+ self .patch_shape = (smallest_image_edge , smallest_image_edge , nb_channels )
122+
123+ self .patch_shape = self .image_shape
124+
106125 mean_value = (self .estimator .clip_values [1 ] - self .estimator .clip_values [0 ]) / 2.0 + self .estimator .clip_values [
107126 0
108127 ]
109128 self .patch = np .ones (shape = self .patch_shape ).astype (np .float32 ) * mean_value
110- self .patch [int (self .patch .shape [0 ] / 2 ), :, :] = 1
111129
112130 def generate (self , x : np .ndarray , y : Optional [np .ndarray ] = None , ** kwargs ) -> np .ndarray :
113131 """
@@ -203,17 +221,27 @@ def _get_circular_patch_mask(self, sharpness: int = 40) -> np.ndarray:
203221 """
204222 Return a circular patch mask
205223 """
206- diameter = self .patch_shape [1 ]
224+ diameter = np .minimum (self .patch_shape [self .i_h ], self .patch_shape [self .i_w ])
225+
207226 x = np .linspace (- 1 , 1 , diameter )
208227 y = np .linspace (- 1 , 1 , diameter )
209228 x_grid , y_grid = np .meshgrid (x , y , sparse = True )
210229 z_grid = (x_grid ** 2 + y_grid ** 2 ) ** sharpness
211230
212231 mask = 1 - np .clip (z_grid , - 1 , 1 )
213232
214- pad_1 = int ((self .patch_shape [1 ] - mask .shape [1 ]) / 2 )
215- pad_2 = int (self .patch_shape [1 ] - pad_1 - mask .shape [1 ])
216- mask = np .pad (mask , pad_width = (pad_1 , pad_2 ), mode = "constant" , constant_values = (0 , 0 ))
233+ pad_h_before = int ((self .image_shape [self .i_h ] - mask .shape [self .i_h ]) / 2 )
234+ pad_h_after = int (self .image_shape [self .i_h ] - pad_h_before - mask .shape [self .i_h ])
235+
236+ pad_w_before = int ((self .image_shape [self .i_w ] - mask .shape [self .i_w ]) / 2 )
237+ pad_w_after = int (self .image_shape [self .i_w ] - pad_w_before - mask .shape [self .i_w ])
238+
239+ mask = np .pad (
240+ mask ,
241+ pad_width = ((pad_h_before , pad_h_after ), (pad_w_before , pad_w_after )),
242+ mode = "constant" ,
243+ constant_values = (0 , 0 ),
244+ )
217245
218246 channel_index = 1 if self .estimator .channels_first else 3
219247 axis = channel_index - 1
@@ -250,57 +278,67 @@ def _augment_images_with_random_patch(self, images, patch, scale=None):
250278 return patched_images , patch_mask_transformed_np , transformations
251279
252280 def _rotate (self , x , angle ):
253- axes = None
254- if not self .estimator .channels_first :
255- axes = (0 , 1 )
256- elif self .estimator .channels_first :
257- axes = (1 , 2 )
281+ axes = (self .i_h , self .i_w )
258282 return rotate (x , angle = angle , reshape = False , axes = axes , order = 1 )
259283
260- def _scale (self , x , scale , shape ):
284+ def _scale (self , x , scale ):
261285 zooms = None
262- if not self . estimator . channels_first :
263- zooms = ( scale , scale , 1.0 )
264- elif self .estimator .channels_first :
286+ height = None
287+ width = None
288+ if self .estimator .channels_first :
265289 zooms = (1.0 , scale , scale )
266- x = zoom (x , zoom = zooms , order = 1 )
267-
268- if x .shape [1 ] <= self .estimator .input_shape [1 ]:
269- pad_1 = int ((shape - x .shape [1 ]) / 2 )
270- pad_2 = int (shape - pad_1 - x .shape [1 ])
271- if not self .estimator .channels_first :
272- pad_width = ((pad_1 , pad_2 ), (pad_1 , pad_2 ), (0 , 0 ))
273- elif self .estimator .channels_first :
274- pad_width = ((0 , 0 ), (pad_1 , pad_2 ), (pad_1 , pad_2 ))
290+ height , width = self .patch_shape [1 :3 ]
291+ elif not self .estimator .channels_first :
292+ zooms = (scale , scale , 1.0 )
293+ height , width = self .patch_shape [0 :2 ]
294+
295+ if scale < 1.0 :
296+ scale_h = int (np .round (height * scale ))
297+ scale_w = int (np .round (width * scale ))
298+ top = (height - scale_h ) // 2
299+ left = (width - scale_w ) // 2
300+
301+ x_out = np .zeros_like (x )
302+ x_out [top : top + scale_h , left : left + scale_w ] = zoom (x , zoom = zooms , order = 1 )
303+
304+ if self .estimator .channels_first :
305+ x_out [:, top : top + scale_h , left : left + scale_w ] = zoom (x , zoom = zooms , order = 1 )
275306 else :
276- pad_width = None
277- x = np .pad (x , pad_width = pad_width , mode = "constant" , constant_values = (0 , 0 ))
278- else :
279- center = int (x .shape [1 ] / 2 )
280- patch_hw_1 = int (self .estimator .input_shape [1 ] / 2 )
281- patch_hw_2 = self .estimator .input_shape [1 ] - patch_hw_1
282- if not self .estimator .channels_first :
283- x = x [center - patch_hw_1 : center + patch_hw_2 , center - patch_hw_1 : center + patch_hw_2 , :]
284- elif self .estimator .channels_first :
285- x = x [:, center - patch_hw_1 : center + patch_hw_2 , center - patch_hw_1 : center + patch_hw_2 ]
307+ x_out [top : top + scale_h , left : left + scale_w , :] = zoom (x , zoom = zooms , order = 1 )
308+
309+ elif scale > 1.0 :
310+ scale_h = int (np .round (height / scale ))
311+ scale_w = int (np .round (width / scale ))
312+ top = (height - scale_h ) // 2
313+ left = (width - scale_w ) // 2
314+
315+ x_out = zoom (x [top : top + scale_h , left : left + scale_w ], zoom = zooms , order = 1 )
316+
317+ cut_top = (x_out .shape [self .i_h ] - height ) // 2
318+ cut_left = (x_out .shape [self .i_w ] - width ) // 2
319+
320+ if self .estimator .channels_first :
321+ x_out = x_out [:, cut_top : cut_top + height , cut_left : cut_left + width ]
286322 else :
287- x = None
323+ x_out = x_out [cut_top : cut_top + height , cut_left : cut_left + width , :]
324+
325+ else :
326+ x_out = x
288327
289- return x
328+ return x_out
329+
330+ def _shift (self , x , shift_h , shift_w ):
331+ if self .estimator .channels_first :
332+ shift_hw = (0 , shift_h , shift_w )
333+ else :
334+ shift_hw = (shift_h , shift_w , 0 )
290335
291- def _shift (self , x , shift_1 , shift_2 ):
292- shift_xy = None
293- if not self .estimator .channels_first :
294- shift_xy = (shift_1 , shift_2 , 0 )
295- elif self .estimator .channels_first :
296- shift_xy = (0 , shift_1 , shift_2 )
297- x = shift (x , shift = shift_xy , order = 1 )
298- return x , shift_1 , shift_2
336+ x = shift (x , shift = shift_hw , order = 1 )
337+ return x , shift_h , shift_w
299338
300339 def _random_transformation (self , patch , scale ):
301340 patch_mask = self ._get_circular_patch_mask ()
302341 transformation = dict ()
303- shape = patch_mask .shape [1 ]
304342
305343 # rotate
306344 angle = random .uniform (- self .rotation_max , self .rotation_max )
@@ -311,17 +349,18 @@ def _random_transformation(self, patch, scale):
311349 # scale
312350 if scale is None :
313351 scale = random .uniform (self .scale_min , self .scale_max )
314- patch = self ._scale (patch , scale , shape )
315- patch_mask = self ._scale (patch_mask , scale , shape )
352+ patch = self ._scale (patch , scale )
353+ patch_mask = self ._scale (patch_mask , scale )
316354 transformation ["scale" ] = scale
317355
318356 # shift
319- shift_max = (self .estimator .input_shape [1 ] - self .patch_shape [1 ] * scale ) / 2.0
320- if shift_max > 0 :
321- shift_1 = random .uniform (- shift_max , shift_max )
322- shift_2 = random .uniform (- shift_max , shift_max )
323- patch , _ , _ = self ._shift (patch , shift_1 , shift_2 )
324- patch_mask , shift_1 , shift_2 = self ._shift (patch_mask , shift_1 , shift_2 )
357+ shift_max_h = (self .estimator .input_shape [self .i_h ] - self .patch_shape [self .i_h ] * scale ) / 2.0
358+ shift_max_w = (self .estimator .input_shape [self .i_w ] - self .patch_shape [self .i_w ] * scale ) / 2.0
359+ if shift_max_h > 0 and shift_max_w > 0 :
360+ shift_h = random .uniform (- shift_max_h , shift_max_h )
361+ shift_w = random .uniform (- shift_max_w , shift_max_w )
362+ patch , _ , _ = self ._shift (patch , shift_h , shift_w )
363+ patch_mask , shift_1 , shift_2 = self ._shift (patch_mask , shift_h , shift_w )
325364 transformation ["shift_1" ] = shift_1
326365 transformation ["shift_2" ] = shift_2
327366 else :
@@ -330,17 +369,16 @@ def _random_transformation(self, patch, scale):
330369 return patch , patch_mask , transformation
331370
332371 def _reverse_transformation (self , gradients : np .ndarray , patch_mask_transformed , transformation ) -> np .ndarray :
333- shape = gradients .shape [1 ]
334372 gradients = gradients * patch_mask_transformed
335373
336374 # shift
337- shift_1 = transformation ["shift_1 " ]
338- shift_2 = transformation ["shift_2 " ]
339- gradients , _ , _ = self ._shift (gradients , - shift_1 , - shift_2 )
375+ shift_h = transformation ["shift_h " ]
376+ shift_w = transformation ["shift_w " ]
377+ gradients , _ , _ = self ._shift (gradients , - shift_h , - shift_w )
340378
341379 # scale
342380 scale = transformation ["scale" ]
343- gradients = self ._scale (gradients , 1.0 / scale , shape )
381+ gradients = self ._scale (gradients , 1.0 / scale )
344382
345383 # rotate
346384 angle = transformation ["rotate" ]
0 commit comments