 
 It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
 """
+import random
+import numpy as np
 import torch.utils.data as data
 from PIL import Image
 import torchvision.transforms as transforms
@@ -58,103 +60,84 @@ def __getitem__(self, index):
         pass
 
 
-def get_transform(opt, grayscale=False, convert=True, crop=True, flip=True):
-    """Create a torchvision transformation function
+def get_params(opt, size):
+    w, h = size
+    new_h = h
+    new_w = w
+    if opt.preprocess == 'resize_and_crop':
+        new_h = new_w = opt.load_size
+    elif opt.preprocess == 'scale_width_and_crop':
+        new_w = opt.load_size
+        new_h = opt.load_size * h // w
 
-    The type of transformation is defined by option (e.g., [opt.preprocess], [opt.load_size], [opt.crop_size])
-    and can be overwritten by arguments such as [convert], [crop], and [flip]
+    x = random.randint(0, np.maximum(0, new_w - opt.crop_size))
+    y = random.randint(0, np.maximum(0, new_h - opt.crop_size))
 
-    Parameters:
-        opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
-        grayscale (bool)   -- if convert input RGB image to a grayscale image
-        convert (bool)     -- if convert an image to a tensor array betwen [-1, 1]
-        crop (bool)        -- if apply cropping
-        flip (bool)        -- if apply horizontal flippling
-    """
+    flip = random.random() > 0.5
+
+    return {'crop_pos': (x, y), 'flip': flip}
+
+
+def get_transform(opt, params, grayscale=False, method=Image.BICUBIC, convert=True):
     transform_list = []
     if grayscale:
         transform_list.append(transforms.Grayscale(1))
-    if opt.preprocess == 'resize_and_crop':
+    if 'resize' in opt.preprocess:
         osize = [opt.load_size, opt.load_size]
-        transform_list.append(transforms.Resize(osize, Image.BICUBIC))
-        transform_list.append(transforms.RandomCrop(opt.crop_size))
-    elif opt.preprocess == 'crop' and crop:
-        transform_list.append(transforms.RandomCrop(opt.crop_size))
-    elif opt.preprocess == 'scale_width':
-        transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.crop_size)))
-    elif opt.preprocess == 'scale_width_and_crop':
-        transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size)))
-        if crop:
-            transform_list.append(transforms.RandomCrop(opt.crop_size))
-    elif opt.preprocess == 'none':
-        transform_list.append(transforms.Lambda(lambda img: __adjust(img)))
-    else:
-        raise ValueError('--preprocess %s is not a valid option.' % opt.preprocess)
-
-    if not opt.no_flip and flip:
-        transform_list.append(transforms.RandomHorizontalFlip())
+        transform_list.append(transforms.Scale(osize, method))
+    elif 'scale_width' in opt.preprocess:
+        transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method)))
+
+    if 'crop' in opt.preprocess:
+        transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))
+
+    if opt.preprocess == 'none':
+        base = 4
+        transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method)))
+
+    if not opt.no_flip and params['flip']:
+        transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
+
     if convert:
         transform_list += [transforms.ToTensor(),
                            transforms.Normalize((0.5, 0.5, 0.5),
                                                 (0.5, 0.5, 0.5))]
     return transforms.Compose(transform_list)
 
 
-def __adjust(img):
-    """Modify the width and height to be multiple of 4.
-
-    Parameters:
-        img (PIL image) -- input image
-
-    Returns a modified image whose width and height are mulitple of 4.
-
-    the size needs to be a multiple of 4,
-    because going through generator network may change img size
-    and eventually cause size mismatch error
-    """
+def __make_power_2(img, base, method=Image.BICUBIC):
     ow, oh = img.size
-    mult = 4
-    if ow % mult == 0 and oh % mult == 0:
+    h = int(round(oh / base) * base)
+    w = int(round(ow / base) * base)
+    if (h == oh) and (w == ow):
         return img
-    w = (ow - 1) // mult
-    w = (w + 1) * mult
-    h = (oh - 1) // mult
-    h = (h + 1) * mult
-
-    if ow != w or oh != h:
-        __print_size_warning(ow, oh, w, h)
-
-    return img.resize((w, h), Image.BICUBIC)
-
 
-def __scale_width(img, target_width):
-    """Resize images so that the width of the output image is the same as a target width
+    __print_size_warning(ow, oh, w, h)
+    return img.resize((w, h), method)
 
-    Parameters:
-        img (PIL image)    -- input image
-        target_width (int) -- target image width
 
-    Returns a modified image whose width matches the target image width;
-
-    the size needs to be a multiple of 4,
-    because going through generator network may change img size
-    and eventually cause size mismatch error
-    """
+def __scale_width(img, target_width, method=Image.BICUBIC):
     ow, oh = img.size
-
-    mult = 4
-    assert target_width % mult == 0, "the target width needs to be multiple of %d." % mult
-    if (ow == target_width and oh % mult == 0):
+    if (ow == target_width):
         return img
     w = target_width
-    target_height = int(target_width * oh / ow)
-    m = (target_height - 1) // mult
-    h = (m + 1) * mult
+    h = int(target_width * oh / ow)
+    return img.resize((w, h), method)
+
+
+def __crop(img, pos, size):
+    ow, oh = img.size
+    x1, y1 = pos
+    tw = th = size
+    if (ow > tw or oh > th):
+        return img.crop((x1, y1, x1 + tw, y1 + th))
+    return img
 
-    if target_height != h:
-        __print_size_warning(target_width, target_height, w, h)
 
-    return img.resize((w, h), Image.BICUBIC)
+def __flip(img, flip):
+    if flip:
+        return img.transpose(Image.FLIP_LEFT_RIGHT)
+    return img
 
 
 def __print_size_warning(ow, oh, w, h):
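The key effect of this diff is that the random choices (crop position and horizontal flip) are drawn once in get_params and then passed explicitly into get_transform, so a paired dataset can apply exactly the same geometry to both images of an A/B pair, instead of letting transforms.RandomCrop and transforms.RandomHorizontalFlip re-draw randomness per call. Below is a minimal usage sketch of that pattern; the option values and image paths are made up for illustration, and it assumes get_params/get_transform are in scope (e.g. from data.base_dataset import get_params, get_transform) with a torchvision version that still provides transforms.Scale, as the diff itself does.

# Hypothetical caller, e.g. a paired dataset's __getitem__ (illustrative values only).
from argparse import Namespace
from PIL import Image

opt = Namespace(preprocess='resize_and_crop', load_size=286, crop_size=256, no_flip=False)

A = Image.open('path/to/A.jpg').convert('RGB')   # hypothetical paths
B = Image.open('path/to/B.jpg').convert('RGB')

params = get_params(opt, A.size)           # draw crop position and flip once per pair
A_transform = get_transform(opt, params)   # both pipelines reuse the same params
B_transform = get_transform(opt, params)

A_tensor = A_transform(A)   # identically cropped and flipped
B_tensor = B_transform(B)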
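For preprocess 'none', __make_power_2 replaces the old __adjust helper: it rounds each side to the nearest multiple of base (4 here) so the generator's downsampling/upsampling stages reproduce the input resolution, which is the size-mismatch concern the removed docstrings described. A quick worked sketch, assuming we are inside base_dataset.py (the helper is module-private) and using a made-up image size:

# Illustrative only: 1023x767 is an arbitrary size, not taken from the diff.
from PIL import Image

img = Image.new('RGB', (1023, 767))
out = __make_power_2(img, base=4)
print(out.size)   # (1024, 768): round(1023 / 4) * 4 = 1024, round(767 / 4) * 4 = 768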