
Commit f27da7d

Refactored the Dataset code to support different preprocess options for pix2pix

1 parent 5762f5b · commit f27da7d
File tree

5 files changed: +92 additions, −109 deletions

data/aligned_dataset.py
data/base_dataset.py
data/colorization_dataset.py
data/single_dataset.py
data/unaligned_dataset.py

data/aligned_dataset.py

Lines changed: 14 additions & 21 deletions

@@ -1,6 +1,6 @@
 import os.path
 import random
-from data.base_dataset import BaseDataset, get_transform
+from data.base_dataset import BaseDataset, get_params, get_transform
 import torchvision.transforms as transforms
 from data.image_folder import make_dataset
 from PIL import Image
@@ -23,12 +23,8 @@ def __init__(self, opt):
         self.dir_AB = os.path.join(opt.dataroot, opt.phase)  # get the image directory
         self.AB_paths = sorted(make_dataset(self.dir_AB, opt.max_dataset_size))  # get image paths
         assert(self.opt.load_size >= self.opt.crop_size)  # crop_size should be smaller than the size of loaded image
-        input_nc = self.opt.output_nc if self.opt.direction == 'BtoA' else self.opt.input_nc
-        output_nc = self.opt.input_nc if self.opt.direction == 'BtoA' else self.opt.output_nc
-        # we manually crop and flip in __getitem__ to make sure we apply the same crop and flip for image A and B
-        # we disable the cropping and flipping in the function get_transform
-        self.transform_A = get_transform(opt, grayscale=(input_nc == 1), crop=False, flip=False)
-        self.transform_B = get_transform(opt, grayscale=(output_nc == 1), crop=False, flip=False)
+        self.input_nc = self.opt.output_nc if self.opt.direction == 'BtoA' else self.opt.input_nc
+        self.output_nc = self.opt.input_nc if self.opt.direction == 'BtoA' else self.opt.output_nc

     def __getitem__(self, index):
         """Return a data point and its metadata information.
@@ -48,20 +44,17 @@ def __getitem__(self, index):
         # split AB image into A and B
         w, h = AB.size
         w2 = int(w / 2)
-        A = AB.crop((0, 0, w2, h)).resize((self.opt.load_size, self.opt.load_size), Image.BICUBIC)
-        B = AB.crop((w2, 0, w, h)).resize((self.opt.load_size, self.opt.load_size), Image.BICUBIC)
-        # apply the same cropping to both A and B
-        if 'crop' in self.opt.preprocess:
-            x, y, h, w = transforms.RandomCrop.get_params(A, output_size=[self.opt.crop_size, self.opt.crop_size])
-            A = A.crop((x, y, w, h))
-            B = B.crop((x, y, w, h))
-        # apply the same flipping to both A and B
-        if (not self.opt.no_flip) and random.random() < 0.5:
-            A = A.transpose(Image.FLIP_LEFT_RIGHT)
-            B = B.transpose(Image.FLIP_LEFT_RIGHT)
-        # call standard transformation function
-        A = self.transform_A(A)
-        B = self.transform_B(B)
+        A = AB.crop((0, 0, w2, h))
+        B = AB.crop((w2, 0, w, h))
+
+        # apply the same transform to both A and B
+        transform_params = get_params(self.opt, A.size)
+        A_transform = get_transform(self.opt, transform_params, grayscale=(self.input_nc == 1))
+        B_transform = get_transform(self.opt, transform_params, grayscale=(self.output_nc == 1))
+
+        A = A_transform(A)
+        B = B_transform(B)

         return {'A': A, 'B': B, 'A_paths': AB_path, 'B_paths': AB_path}

     def __len__(self):
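
Worth spelling out, since it is the heart of the commit: the old code disabled cropping and flipping inside get_transform and re-implemented both by hand in __getitem__ to keep A and B aligned; the new code samples the random parameters once via get_params and builds two deterministic transforms from them, so any --preprocess mode stays pair-consistent. A minimal sketch of the pattern (the opt namespace and image path are illustrative stand-ins, not from the commit):

# Sketch of the new aligned-transform pattern; `opt` and the image path are
# illustrative stand-ins, not part of the commit.
from argparse import Namespace

from PIL import Image

from data.base_dataset import get_params, get_transform

opt = Namespace(preprocess='resize_and_crop', load_size=286,
                crop_size=256, no_flip=False)

AB = Image.open('datasets/facades/train/1.jpg').convert('RGB')  # hypothetical path
w, h = AB.size
A = AB.crop((0, 0, w // 2, h))       # left half: input
B = AB.crop((w // 2, 0, w, h))       # right half: target

params = get_params(opt, A.size)     # sample crop position and flip ONCE
A = get_transform(opt, params)(A)    # both transforms replay the same
B = get_transform(opt, params)(B)    # crop_pos and flip decision

Because all randomness now lives in get_params, get_transform becomes a pure function of (opt, params), which is what lets AlignedDataset, SingleDataset, and UnalignedDataset share it.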

data/base_dataset.py

Lines changed: 57 additions & 74 deletions

@@ -2,6 +2,8 @@

 It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
 """
+import random
+import numpy as np
 import torch.utils.data as data
 from PIL import Image
 import torchvision.transforms as transforms
@@ -58,103 +60,84 @@ def __getitem__(self, index):
         pass


-def get_transform(opt, grayscale=False, convert=True, crop=True, flip=True):
-    """Create a torchvision transformation function
-
-    The type of transformation is defined by option (e.g., [opt.preprocess], [opt.load_size], [opt.crop_size])
-    and can be overwritten by arguments such as [convert], [crop], and [flip]
-
-    Parameters:
-        opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
-        grayscale (bool) -- if convert input RGB image to a grayscale image
-        convert (bool) -- if convert an image to a tensor array betwen [-1, 1]
-        crop (bool) -- if apply cropping
-        flip (bool) -- if apply horizontal flippling
-    """
+def get_params(opt, size):
+    w, h = size
+    new_h = h
+    new_w = w
+    if opt.preprocess == 'resize_and_crop':
+        new_h = new_w = opt.load_size
+    elif opt.preprocess == 'scale_width_and_crop':
+        new_w = opt.load_size
+        new_h = opt.load_size * h // w
+
+    x = random.randint(0, np.maximum(0, new_w - opt.crop_size))
+    y = random.randint(0, np.maximum(0, new_h - opt.crop_size))
+
+    flip = random.random() > 0.5
+
+    return {'crop_pos': (x, y), 'flip': flip}
+
+
+def get_transform(opt, params, grayscale=False, method=Image.BICUBIC, convert=True):
     transform_list = []
     if grayscale:
         transform_list.append(transforms.Grayscale(1))
-    if opt.preprocess == 'resize_and_crop':
+    if 'resize' in opt.preprocess:
         osize = [opt.load_size, opt.load_size]
-        transform_list.append(transforms.Resize(osize, Image.BICUBIC))
-        transform_list.append(transforms.RandomCrop(opt.crop_size))
-    elif opt.preprocess == 'crop' and crop:
-        transform_list.append(transforms.RandomCrop(opt.crop_size))
-    elif opt.preprocess == 'scale_width':
-        transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.crop_size)))
-    elif opt.preprocess == 'scale_width_and_crop':
-        transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size)))
-        if crop:
-            transform_list.append(transforms.RandomCrop(opt.crop_size))
-    elif opt.preprocess == 'none':
-        transform_list.append(transforms.Lambda(lambda img: __adjust(img)))
-    else:
-        raise ValueError('--preprocess %s is not a valid option.' % opt.preprocess)
-
-    if not opt.no_flip and flip:
-        transform_list.append(transforms.RandomHorizontalFlip())
+        transform_list.append(transforms.Scale(osize, method))
+    elif 'scale_width' in opt.preprocess:
+        transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method)))
+
+    if 'crop' in opt.preprocess:
+        transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))
+
+    if opt.preprocess == 'none':
+        base = 4
+        transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method)))
+
+    if not opt.no_flip and params['flip']:
+        transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))

     if convert:
         transform_list += [transforms.ToTensor(),
                            transforms.Normalize((0.5, 0.5, 0.5),
                                                 (0.5, 0.5, 0.5))]
     return transforms.Compose(transform_list)


-def __adjust(img):
-    """Modify the width and height to be multiple of 4.
-
-    Parameters:
-        img (PIL image) -- input image
-
-    Returns a modified image whose width and height are mulitple of 4.
-
-    the size needs to be a multiple of 4,
-    because going through generator network may change img size
-    and eventually cause size mismatch error
-    """
+def __make_power_2(img, base, method=Image.BICUBIC):
     ow, oh = img.size
-    mult = 4
-    if ow % mult == 0 and oh % mult == 0:
+    h = int(round(oh / base) * base)
+    w = int(round(ow / base) * base)
+    if (h == oh) and (w == ow):
         return img
-    w = (ow - 1) // mult
-    w = (w + 1) * mult
-    h = (oh - 1) // mult
-    h = (h + 1) * mult
-
-    if ow != w or oh != h:
-        __print_size_warning(ow, oh, w, h)
-
-    return img.resize((w, h), Image.BICUBIC)
-

-def __scale_width(img, target_width):
-    """Resize images so that the width of the output image is the same as a target width
-
-    Parameters:
-        img (PIL image) -- input image
-        target_width (int) -- target image width
-
-    Returns a modified image whose width matches the target image width;
-
-    the size needs to be a multiple of 4,
-    because going through generator network may change img size
-    and eventually cause size mismatch error
-    """
+    __print_size_warning(ow, oh, w, h)
+    return img.resize((w, h), method)
+
+
+def __scale_width(img, target_width, method=Image.BICUBIC):
     ow, oh = img.size
-
-    mult = 4
-    assert target_width % mult == 0, "the target width needs to be multiple of %d." % mult
-    if (ow == target_width and oh % mult == 0):
+    if (ow == target_width):
         return img
     w = target_width
-    target_height = int(target_width * oh / ow)
-    m = (target_height - 1) // mult
-    h = (m + 1) * mult
-
-    if target_height != h:
-        __print_size_warning(target_width, target_height, w, h)
-
-    return img.resize((w, h), Image.BICUBIC)
+    h = int(target_width * oh / ow)
+    return img.resize((w, h), method)
+
+
+def __crop(img, pos, size):
+    ow, oh = img.size
+    x1, y1 = pos
+    tw = th = size
+    if (ow > tw or oh > th):
+        return img.crop((x1, y1, x1 + tw, y1 + th))
+    return img
+
+
+def __flip(img, flip):
+    if flip:
+        return img.transpose(Image.FLIP_LEFT_RIGHT)
+    return img


 def __print_size_warning(ow, oh, w, h):
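
get_params is now the single source of randomness: it computes the post-resize dimensions implied by opt.preprocess, then draws a crop position and a flip decision for get_transform to replay. One subtlety: if the scaled image ends up smaller than crop_size along an axis, np.maximum(0, ...) clamps that crop coordinate to 0. An illustrative check, with example numbers that are assumptions rather than anything in the commit:

# Illustrative check of get_params; the sizes below are example assumptions.
from argparse import Namespace

from data.base_dataset import get_params

opt = Namespace(preprocess='scale_width_and_crop', load_size=286, crop_size=256)
params = get_params(opt, size=(600, 400))   # a hypothetical 600x400 source image

# scale_width gives new_w = 286 and new_h = 286 * 400 // 600 = 190.
# 190 < crop_size, so np.maximum(0, 190 - 256) clamps y to 0;
# x is drawn uniformly from [0, 286 - 256] = [0, 30].
assert params['crop_pos'][1] == 0
assert 0 <= params['crop_pos'][0] <= 30
assert params['flip'] in (True, False)

(Separately, the new get_transform uses transforms.Scale, which is the older torchvision name for transforms.Resize and is deprecated in later releases.)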

data/colorization_dataset.py

Lines changed: 4 additions & 3 deletions

@@ -1,5 +1,5 @@
 import os.path
-from data.base_dataset import BaseDataset, get_transform
+from data.base_dataset import BaseDataset, get_params, get_transform
 from data.image_folder import make_dataset
 from skimage import color  # require skimage
 from PIL import Image
@@ -39,7 +39,6 @@ def __init__(self, opt):
         self.dir = os.path.join(opt.dataroot)
         self.AB_paths = sorted(make_dataset(self.dir, opt.max_dataset_size))
         assert(opt.input_nc == 1 and opt.output_nc == 2 and opt.direction == 'AtoB')
-        self.transform = get_transform(opt, convert=False)

     def __getitem__(self, index):
         """Return a data point and its metadata information.
@@ -55,7 +54,9 @@ def __getitem__(self, index):
         """
         path = self.AB_paths[index]
         im = Image.open(path).convert('RGB')
-        im = self.transform(im)
+        transform_params = get_params(self.opt, im.size)
+        transform = get_transform(self.opt, transform_params, convert=False)
+        im = transform(im)
         im = np.array(im)
         lab = color.rgb2lab(im).astype(np.float32)
         lab_t = transforms.ToTensor()(lab)
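
With convert=False, the composed transform skips ToTensor and Normalize and still returns a PIL image, which the dataset then converts to Lab via skimage. A condensed, self-contained sketch of that per-item pipeline (opt and path are hypothetical stand-ins, not from the commit):

# Condensed sketch of the per-item colorization pipeline; `opt` and `path`
# are illustrative stand-ins, not from the commit.
from argparse import Namespace

import numpy as np
import torchvision.transforms as transforms
from PIL import Image
from skimage import color

from data.base_dataset import get_params, get_transform

opt = Namespace(preprocess='resize_and_crop', load_size=286,
                crop_size=256, no_flip=False)
path = 'datasets/colorization/train/1.jpg'   # hypothetical path

im = Image.open(path).convert('RGB')
transform = get_transform(opt, get_params(opt, im.size), convert=False)
im = np.array(transform(im))                 # H x W x 3 uint8, still RGB
lab = color.rgb2lab(im).astype(np.float32)   # L in [0, 100]; a, b roughly [-110, 110]
lab_t = transforms.ToTensor()(lab)           # 3 x H x W float tensor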

data/single_dataset.py

Lines changed: 6 additions & 4 deletions

@@ -1,4 +1,4 @@
-from data.base_dataset import BaseDataset, get_transform
+from data.base_dataset import BaseDataset, get_params, get_transform
 from data.image_folder import make_dataset
 from PIL import Image

@@ -17,8 +17,8 @@ def __init__(self, opt):
         """
         BaseDataset.__init__(self, opt)
         self.A_paths = sorted(make_dataset(opt.dataroot, opt.max_dataset_size))
-        input_nc = self.opt.output_nc if self.opt.direction == 'BtoA' else self.opt.input_nc
-        self.transform = get_transform(opt, input_nc == 1)
+        self.input_nc = self.opt.output_nc if self.opt.direction == 'BtoA' else self.opt.input_nc
+        # self.transform = get_transform(opt, input_nc == 1)

     def __getitem__(self, index):
         """Return a data point and its metadata information.
@@ -32,7 +32,9 @@ def __getitem__(self, index):
         """
         A_path = self.A_paths[index]
         A_img = Image.open(A_path).convert('RGB')
-        A = self.transform(A_img)
+        transform_params = get_params(self.opt, A_img.size)
+        transform = get_transform(self.opt, transform_params, grayscale=(self.input_nc == 1))
+        A = transform(A_img)
         return {'A': A, 'A_paths': A_path}

     def __len__(self):
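
Since the transform is now rebuilt from freshly sampled parameters on every __getitem__ call, two reads of the same index can legitimately differ whenever cropping or flipping is active. A hypothetical smoke test of that property (the opt values and dataroot below are assumptions, not from the commit):

# Hypothetical smoke test, not part of the commit: per-item parameter
# sampling means repeated reads of one index may differ.
from argparse import Namespace

import torch

from data.single_dataset import SingleDataset

opt = Namespace(dataroot='datasets/facades/testA',   # hypothetical path
                max_dataset_size=float('inf'), direction='AtoB',
                input_nc=3, output_nc=3, preprocess='resize_and_crop',
                load_size=286, crop_size=256, no_flip=False)

ds = SingleDataset(opt)
a0 = ds[0]['A']
a1 = ds[0]['A']
# May print False while cropping/flipping is active; must print True
# with opt.preprocess == 'none' and opt.no_flip == True.
print(torch.equal(a0, a1))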

data/unaligned_dataset.py

Lines changed: 11 additions & 7 deletions

@@ -1,5 +1,5 @@
 import os.path
-from data.base_dataset import BaseDataset, get_transform
+from data.base_dataset import BaseDataset, get_params, get_transform
 from data.image_folder import make_dataset
 from PIL import Image
 import random
@@ -31,10 +31,8 @@ def __init__(self, opt):
         self.A_size = len(self.A_paths)  # get the size of dataset A
         self.B_size = len(self.B_paths)  # get the size of dataset B
         btoA = self.opt.direction == 'BtoA'
-        input_nc = self.opt.output_nc if btoA else self.opt.input_nc       # get the number of channels of input image
-        output_nc = self.opt.input_nc if btoA else self.opt.output_nc      # get the number of channels of output image
-        self.transform_A = get_transform(opt, grayscale=(input_nc == 1))   # if nc == 1, we convert RGB to grayscale image
-        self.transform_B = get_transform(opt, grayscale=(output_nc == 1))  # if nc == 1, we convert RGB to grayscale image
+        self.input_nc = self.opt.output_nc if btoA else self.opt.input_nc   # get the number of channels of input image
+        self.output_nc = self.opt.input_nc if btoA else self.opt.output_nc  # get the number of channels of output image

     def __getitem__(self, index):
         """Return a data point and its metadata information.
@@ -57,8 +55,14 @@ def __getitem__(self, index):
         A_img = Image.open(A_path).convert('RGB')
         B_img = Image.open(B_path).convert('RGB')
         # apply image transformation
-        A = self.transform_A(A_img)
-        B = self.transform_B(B_img)
+        A_transform_params = get_params(self.opt, A_img.size)
+        A_transform = get_transform(self.opt, A_transform_params, grayscale=(self.input_nc == 1))
+        A = A_transform(A_img)
+
+        B_transform_params = get_params(self.opt, B_img.size)
+        B_transform = get_transform(self.opt, B_transform_params, grayscale=(self.output_nc == 1))
+        B = B_transform(B_img)
+
         return {'A': A, 'B': B, 'A_paths': A_path, 'B_paths': B_path}

     def __len__(self):
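
The contrast with AlignedDataset is the point here: unpaired A and B each draw their own parameters, so crops and flips are sampled independently, which is the right behavior for CycleGAN-style data that is not pixel-aligned. A side-by-side sketch of the two patterns (images and opt are hypothetical):

# Side-by-side of the two sampling patterns; images and opt are hypothetical.
from argparse import Namespace

from PIL import Image

from data.base_dataset import get_params, get_transform

opt = Namespace(preprocess='resize_and_crop', load_size=286,
                crop_size=256, no_flip=False)
A_img = Image.new('RGB', (300, 300))   # stand-ins for real dataset images
B_img = Image.new('RGB', (300, 300))

# AlignedDataset: one parameter draw, replayed on both halves of the pair.
shared = get_params(opt, A_img.size)
A = get_transform(opt, shared)(A_img)
B = get_transform(opt, shared)(B_img)

# UnalignedDataset: independent draws, because A and B are not pixel-aligned.
A = get_transform(opt, get_params(opt, A_img.size))(A_img)
B = get_transform(opt, get_params(opt, B_img.size))(B_img)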
