Skip to content

Commit e845a1d

Browse files
committed
Enhance image.py for gray image.
1 parent 5752892 commit e845a1d

File tree

1 file changed

+14
-10
lines changed

1 file changed

+14
-10
lines changed

python/paddle/v2/image.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,6 @@ def resize_short(im, size):
176176
:param size: the shorter edge size of image after resizing.
177177
:type size: int
178178
"""
179-
assert im.shape[-1] == 1 or im.shape[-1] == 3
180179
h, w = im.shape[:2]
181180
h_new, w_new = size, size
182181
if h > w:
@@ -267,7 +266,7 @@ def random_crop(im, size, is_color=True):
267266
return im
268267

269268

270-
def left_right_flip(im):
269+
def left_right_flip(im, is_color=True):
271270
"""
272271
Flip an image along the horizontal direction.
273272
Return the flipped image.
@@ -278,13 +277,16 @@ def left_right_flip(im):
278277
279278
im = left_right_flip(im)
280279
281-
:paam im: input image with HWC layout
280+
:paam im: input image with HWC layout or HW layout for gray image
281+
282282
:type im: ndarray
283+
:paam is_color: whether color input image or not
284+
:type is_color: bool
283285
"""
284-
if len(im.shape) == 3:
286+
if len(im.shape) == 3 and is_color:
285287
return im[:, ::-1, :]
286288
else:
287-
return im[:, ::-1, :]
289+
return im[:, ::-1]
288290

289291

290292
def simple_transform(im,
@@ -319,20 +321,22 @@ def simple_transform(im,
319321
"""
320322
im = resize_short(im, resize_size)
321323
if is_train:
322-
im = random_crop(im, crop_size)
324+
im = random_crop(im, crop_size, is_color)
323325
if np.random.randint(2) == 0:
324-
im = left_right_flip(im)
326+
im = left_right_flip(im, is_color)
325327
else:
326-
im = center_crop(im, crop_size)
328+
im = center_crop(im, crop_size, is_color)
327329
if len(im.shape) == 3:
328330
im = to_chw(im)
329331

330332
im = im.astype('float32')
331333
if mean is not None:
332334
mean = np.array(mean, dtype=np.float32)
333335
# mean value, may be one value per channel
334-
if mean.ndim == 1:
336+
if mean.ndim == 1 and is_color:
335337
mean = mean[:, np.newaxis, np.newaxis]
338+
elif mean.ndim == 1:
339+
mean = mean
336340
else:
337341
# elementwise mean
338342
assert len(mean.shape) == len(im)
@@ -372,6 +376,6 @@ def load_and_transform(filename,
372376
mean values per channel.
373377
:type mean: numpy array | list
374378
"""
375-
im = load_image(filename)
379+
im = load_image(filename, is_color)
376380
im = simple_transform(im, resize_size, crop_size, is_train, is_color, mean)
377381
return im

0 commit comments

Comments
 (0)