@@ -176,7 +176,6 @@ def resize_short(im, size):
176
176
:param size: the shorter edge size of image after resizing.
177
177
:type size: int
178
178
"""
179
- assert im .shape [- 1 ] == 1 or im .shape [- 1 ] == 3
180
179
h , w = im .shape [:2 ]
181
180
h_new , w_new = size , size
182
181
if h > w :
@@ -267,7 +266,7 @@ def random_crop(im, size, is_color=True):
267
266
return im
268
267
269
268
270
- def left_right_flip (im ):
269
+ def left_right_flip (im , is_color = True ):
271
270
"""
272
271
Flip an image along the horizontal direction.
273
272
Return the flipped image.
@@ -278,13 +277,16 @@ def left_right_flip(im):
278
277
279
278
im = left_right_flip(im)
280
279
281
- :paam im: input image with HWC layout
280
+ :param im: input image with HWC layout or HW layout for gray image
281
+
282
282
:type im: ndarray
283
+ :param is_color: whether the input image is a color image or not
284
+ :type is_color: bool
283
285
"""
284
- if len (im .shape ) == 3 :
286
+ if len (im .shape ) == 3 and is_color :
285
287
return im [:, ::- 1 , :]
286
288
else :
287
- return im [:, ::- 1 , : ]
289
+ return im [:, ::- 1 ]
288
290
289
291
290
292
def simple_transform (im ,
@@ -319,20 +321,22 @@ def simple_transform(im,
319
321
"""
320
322
im = resize_short (im , resize_size )
321
323
if is_train :
322
- im = random_crop (im , crop_size )
324
+ im = random_crop (im , crop_size , is_color )
323
325
if np .random .randint (2 ) == 0 :
324
- im = left_right_flip (im )
326
+ im = left_right_flip (im , is_color )
325
327
else :
326
- im = center_crop (im , crop_size )
328
+ im = center_crop (im , crop_size , is_color )
327
329
if len (im .shape ) == 3 :
328
330
im = to_chw (im )
329
331
330
332
im = im .astype ('float32' )
331
333
if mean is not None :
332
334
mean = np .array (mean , dtype = np .float32 )
333
335
# mean value, may be one value per channel
334
- if mean .ndim == 1 :
336
+ if mean .ndim == 1 and is_color :
335
337
mean = mean [:, np .newaxis , np .newaxis ]
338
+ elif mean .ndim == 1 :
339
+ mean = mean
336
340
else :
337
341
# elementwise mean
338
342
assert len (mean .shape ) == len (im )
@@ -372,6 +376,6 @@ def load_and_transform(filename,
372
376
mean values per channel.
373
377
:type mean: numpy array | list
374
378
"""
375
- im = load_image (filename )
379
+ im = load_image (filename , is_color )
376
380
im = simple_transform (im , resize_size , crop_size , is_train , is_color , mean )
377
381
return im
0 commit comments