@@ -176,7 +176,6 @@ def resize_short(im, size):
176
176
:param size: the shorter edge size of image after resizing.
177
177
:type size: int
178
178
"""
179
- assert im .shape [- 1 ] == 1 or im .shape [- 1 ] == 3
180
179
h , w = im .shape [:2 ]
181
180
h_new , w_new = size , size
182
181
if h > w :
@@ -267,7 +266,7 @@ def random_crop(im, size, is_color=True):
267
266
return im
268
267
269
268
270
- def left_right_flip (im ):
269
+ def left_right_flip (im , is_color = True ):
271
270
"""
272
271
Flip an image along the horizontal direction.
273
272
Return the flipped image.
@@ -278,13 +277,15 @@ def left_right_flip(im):
278
277
279
278
im = left_right_flip(im)
280
279
281
- :paam im: input image with HWC layout
280
+ :param im: input image with HWC layout or HW layout for gray image
282
281
:type im: ndarray
282
+ :param is_color: whether input image is color or not
283
+ :type is_color: bool
283
284
"""
284
- if len (im .shape ) == 3 :
285
+ if len (im .shape ) == 3 and is_color :
285
286
return im [:, ::- 1 , :]
286
287
else :
287
- return im [:, ::- 1 , : ]
288
+ return im [:, ::- 1 ]
288
289
289
290
290
291
def simple_transform (im ,
@@ -321,8 +322,9 @@ def simple_transform(im,
321
322
if is_train :
322
323
im = random_crop (im , crop_size , is_color = is_color )
323
324
if np .random .randint (2 ) == 0 :
324
- im = left_right_flip (im )
325
+ im = left_right_flip (im , is_color )
325
326
else :
327
+ im = center_crop (im , crop_size , is_color )
326
328
im = center_crop (im , crop_size , is_color = is_color )
327
329
if len (im .shape ) == 3 :
328
330
im = to_chw (im )
@@ -331,8 +333,10 @@ def simple_transform(im,
331
333
if mean is not None :
332
334
mean = np .array (mean , dtype = np .float32 )
333
335
# mean value, may be one value per channel
334
- if mean .ndim == 1 :
336
+ if mean .ndim == 1 and is_color :
335
337
mean = mean [:, np .newaxis , np .newaxis ]
338
+ elif mean .ndim == 1 :
339
+ mean = mean
336
340
else :
337
341
# elementwise mean
338
342
assert len (mean .shape ) == len (im )
@@ -372,6 +376,6 @@ def load_and_transform(filename,
372
376
mean values per channel.
373
377
:type mean: numpy array | list
374
378
"""
375
- im = load_image (filename )
379
+ im = load_image (filename , is_color )
376
380
im = simple_transform (im , resize_size , crop_size , is_train , is_color , mean )
377
381
return im
0 commit comments