diff --git a/cgi-bin/paint_x2_unet/img2imgDataset.py b/cgi-bin/paint_x2_unet/img2imgDataset.py index 86d0247..3d424a7 100755 --- a/cgi-bin/paint_x2_unet/img2imgDataset.py +++ b/cgi-bin/paint_x2_unet/img2imgDataset.py @@ -21,6 +21,23 @@ def cvt2YUV(img): img = cv2.cvtColor( img, cv2.COLOR_BGR2YUV ) return img +def cvt2GRAY(img): + if len(img.shape) == 2: + # Grayscale image + return img + width, height, color = img.shape + if color == 4: + # RGBA image + r, g, b, a = cv2.split(img) + white = (255 - a).repeat(3).reshape((width, height, 3)) + img2 = cv2.merge((r, g, b)).astype(np.uint32) + img2 += white + img2 = img2.clip(0, 255).astype(np.uint8) + return cv2.cvtColor(img2, cv2.COLOR_RGB2GRAY) + else: + # RGB image + return cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) + class ImageAndRefDataset(chainer.dataset.DatasetMixin): def __init__(self, paths, root1='./input', root2='./ref', dtype=np.float32): @@ -39,7 +56,10 @@ def get_example(self, i, minimize=False, blur=0, s_size=128): path1 = os.path.join(self._root1, self._paths[i]) #image1 = ImageDataset._read_image_as_array(path1, self._dtype) - image1 = cv2.imread(path1, cv2.IMREAD_GRAYSCALE) + #image1 = cv2.imread(path1, cv2.IMREAD_GRAYSCALE) + image1 = cv2.imread(path1, cv2.IMREAD_UNCHANGED ) + image1 = cvt2GRAY(image1) + print("load:" + path1, os.path.isfile(path1), image1 is None) image1 = np.asarray(image1, self._dtype)