Skip to content

Commit de797ad

Browse files
authored
Merge pull request #7786 from qingqing01/fix_image
Enhance image.py for gray image.
2 parents 7081f21 + 19ac570 commit de797ad

File tree

1 file changed

+12
-8
lines changed

1 file changed

+12
-8
lines changed

python/paddle/v2/image.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,6 @@ def resize_short(im, size):
176176
:param size: the shorter edge size of image after resizing.
177177
:type size: int
178178
"""
179-
assert im.shape[-1] == 1 or im.shape[-1] == 3
180179
h, w = im.shape[:2]
181180
h_new, w_new = size, size
182181
if h > w:
@@ -267,7 +266,7 @@ def random_crop(im, size, is_color=True):
267266
return im
268267

269268

270-
def left_right_flip(im):
269+
def left_right_flip(im, is_color=True):
271270
"""
272271
Flip an image along the horizontal direction.
273272
Return the flipped image.
@@ -278,13 +277,15 @@ def left_right_flip(im):
278277
279278
im = left_right_flip(im)
280279
281-
:paam im: input image with HWC layout
280+
:param im: input image with HWC layout or HW layout for gray image
282281
:type im: ndarray
282+
:param is_color: whether input image is color or not
283+
:type is_color: bool
283284
"""
284-
if len(im.shape) == 3:
285+
if len(im.shape) == 3 and is_color:
285286
return im[:, ::-1, :]
286287
else:
287-
return im[:, ::-1, :]
288+
return im[:, ::-1]
288289

289290

290291
def simple_transform(im,
@@ -321,8 +322,9 @@ def simple_transform(im,
321322
if is_train:
322323
im = random_crop(im, crop_size, is_color=is_color)
323324
if np.random.randint(2) == 0:
324-
im = left_right_flip(im)
325+
im = left_right_flip(im, is_color)
325326
else:
327+
im = center_crop(im, crop_size, is_color)
326328
im = center_crop(im, crop_size, is_color=is_color)
327329
if len(im.shape) == 3:
328330
im = to_chw(im)
@@ -331,8 +333,10 @@ def simple_transform(im,
331333
if mean is not None:
332334
mean = np.array(mean, dtype=np.float32)
333335
# mean value, may be one value per channel
334-
if mean.ndim == 1:
336+
if mean.ndim == 1 and is_color:
335337
mean = mean[:, np.newaxis, np.newaxis]
338+
elif mean.ndim == 1:
339+
mean = mean
336340
else:
337341
# elementwise mean
338342
assert len(mean.shape) == len(im)
@@ -372,6 +376,6 @@ def load_and_transform(filename,
372376
mean values per channel.
373377
:type mean: numpy array | list
374378
"""
375-
im = load_image(filename)
379+
im = load_image(filename, is_color)
376380
im = simple_transform(im, resize_size, crop_size, is_train, is_color, mean)
377381
return im

0 commit comments

Comments
 (0)