
Commit ba80e92

simplify dataset loader code
1 parent d5cfa4d · commit ba80e92

5 files changed: +27 additions, -28 deletions

data/base_dataset.py

Lines changed: 12 additions & 7 deletions
@@ -78,25 +78,30 @@ def get_params(opt, size):
     return {'crop_pos': (x, y), 'flip': flip}
 
 
-def get_transform(opt, params, grayscale=False, method=Image.BICUBIC, convert=True):
+def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True):
     transform_list = []
     if grayscale:
         transform_list.append(transforms.Grayscale(1))
     if 'resize' in opt.preprocess:
         osize = [opt.load_size, opt.load_size]
-        transform_list.append(transforms.Scale(osize, method))
+        transform_list.append(transforms.Resize(osize, method))
     elif 'scale_width' in opt.preprocess:
         transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method)))
 
     if 'crop' in opt.preprocess:
-        transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))
+        if params is None:
+            transform_list.append(transforms.RandomCrop(opt.crop_size))
+        else:
+            transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))
 
     if opt.preprocess == 'none':
-        base = 4
-        transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method)))
+        transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method)))
 
-    if not opt.no_flip and params['flip']:
-        transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
+    if not opt.no_flip:
+        if params is None:
+            transform_list.append(transforms.RandomHorizontalFlip())
+        elif params['flip']:
+            transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
 
     if convert:
         transform_list += [transforms.ToTensor(),
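For context, a minimal usage sketch (not part of the commit) of the two calling styles the new params=None default enables: omitting params gives per-call random augmentation via torchvision, while passing explicit params reproduces a fixed crop/flip so paired images stay aligned. The option values and 'example.jpg' path are illustrative only.

    from argparse import Namespace
    from PIL import Image
    from data.base_dataset import get_params, get_transform

    # Illustrative options; real runs build opt through the project's option parser.
    opt = Namespace(preprocess='resize_and_crop', load_size=286, crop_size=256, no_flip=False)
    img = Image.open('example.jpg').convert('RGB')  # hypothetical input image

    # Unpaired case: omit params and let RandomCrop / RandomHorizontalFlip randomize per call.
    A = get_transform(opt)(img)

    # Paired case: compute params once and pass them in, so a second image can be
    # transformed with exactly the same crop position and flip decision.
    params = get_params(opt, img.size)
    fixed = get_transform(opt, params=params)
    A_fixed = fixed(img)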

data/colorization_dataset.py

Lines changed: 3 additions & 4 deletions
@@ -1,5 +1,5 @@
 import os.path
-from data.base_dataset import BaseDataset, get_params, get_transform
+from data.base_dataset import BaseDataset, get_transform
 from data.image_folder import make_dataset
 from skimage import color # require skimage
 from PIL import Image
@@ -39,6 +39,7 @@ def __init__(self, opt):
         self.dir = os.path.join(opt.dataroot)
         self.AB_paths = sorted(make_dataset(self.dir, opt.max_dataset_size))
         assert(opt.input_nc == 1 and opt.output_nc == 2 and opt.direction == 'AtoB')
+        self.transform = get_transform(self.opt, convert=False)
 
     def __getitem__(self, index):
         """Return a data point and its metadata information.
@@ -54,9 +55,7 @@ def __getitem__(self, index):
         """
         path = self.AB_paths[index]
         im = Image.open(path).convert('RGB')
-        transform_params = get_params(self.opt, im.size)
-        transform = get_transform(self.opt, transform_params, convert=False)
-        im = transform(im)
+        im = self.transform(im)
         im = np.array(im)
         lab = color.rgb2lab(im).astype(np.float32)
         lab_t = transforms.ToTensor()(lab)
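The convert=False argument matters here because self.transform then still returns a PIL image (no ToTensor/Normalize), which is what the Lab conversion in __getitem__ expects. A rough sketch of that downstream pipeline, with an illustrative channel split and a hypothetical 'example.jpg' path:

    import numpy as np
    from PIL import Image
    from skimage import color
    import torchvision.transforms as transforms

    im = Image.open('example.jpg').convert('RGB')   # hypothetical input path
    im = np.array(im)                               # H x W x 3 uint8, as in __getitem__
    lab = color.rgb2lab(im).astype(np.float32)      # L in [0, 100], a/b roughly in [-110, 110]
    lab_t = transforms.ToTensor()(lab)              # 3 x H x W float tensor (no rescaling for float input)
    L, ab = lab_t[[0], ...], lab_t[[1, 2], ...]     # illustrative split into input (L) and target (ab) channels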

data/single_dataset.py

Lines changed: 4 additions & 6 deletions
@@ -1,4 +1,4 @@
-from data.base_dataset import BaseDataset, get_params, get_transform
+from data.base_dataset import BaseDataset, get_transform
 from data.image_folder import make_dataset
 from PIL import Image
 
@@ -17,8 +17,8 @@ def __init__(self, opt):
         """
         BaseDataset.__init__(self, opt)
         self.A_paths = sorted(make_dataset(opt.dataroot, opt.max_dataset_size))
-        self.input_nc = self.opt.output_nc if self.opt.direction == 'BtoA' else self.opt.input_nc
-        # self.transform = get_transform(opt, input_nc == 1)
+        input_nc = self.opt.output_nc if self.opt.direction == 'BtoA' else self.opt.input_nc
+        self.transform = get_transform(opt, grayscale=(input_nc == 1))
 
     def __getitem__(self, index):
         """Return a data point and its metadata information.
@@ -32,9 +32,7 @@ def __getitem__(self, index):
         """
         A_path = self.A_paths[index]
         A_img = Image.open(A_path).convert('RGB')
-        transform_params = get_params(self.opt, A_img.size)
-        transform = get_transform(self.opt, transform_params, grayscale=(self.input_nc == 1))
-        A = transform(A_img)
+        A = self.transform(A_img)
        return {'A': A, 'A_paths': A_path}
 
    def __len__(self):
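A small sketch (again not part of the commit) of what the grayscale=(input_nc == 1) flag toggles; the shapes assume the convert branch normalizes a single channel for grayscale input, and 'example.jpg' is hypothetical.

    from argparse import Namespace
    from PIL import Image
    from data.base_dataset import get_transform

    opt = Namespace(preprocess='resize_and_crop', load_size=286, crop_size=256, no_flip=True)
    img = Image.open('example.jpg').convert('RGB')    # hypothetical input image

    gray = get_transform(opt, grayscale=True)(img)    # Grayscale(1) is prepended to the pipeline
    rgb = get_transform(opt, grayscale=False)(img)
    print(gray.shape, rgb.shape)                      # expected: [1, 256, 256] vs [3, 256, 256]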

data/unaligned_dataset.py

Lines changed: 7 additions & 10 deletions
@@ -1,5 +1,5 @@
 import os.path
-from data.base_dataset import BaseDataset, get_params, get_transform
+from data.base_dataset import BaseDataset, get_transform
 from data.image_folder import make_dataset
 from PIL import Image
 import random
@@ -31,8 +31,10 @@ def __init__(self, opt):
         self.A_size = len(self.A_paths) # get the size of dataset A
         self.B_size = len(self.B_paths) # get the size of dataset B
         btoA = self.opt.direction == 'BtoA'
-        self.input_nc = self.opt.output_nc if btoA else self.opt.input_nc # get the number of channels of input image
-        self.output_nc = self.opt.input_nc if btoA else self.opt.output_nc # get the number of channels of output image
+        input_nc = self.opt.output_nc if btoA else self.opt.input_nc # get the number of channels of input image
+        output_nc = self.opt.input_nc if btoA else self.opt.output_nc # get the number of channels of output image
+        self.transform_A = get_transform(self.opt, grayscale=(input_nc == 1))
+        self.transform_B = get_transform(self.opt, grayscale=(output_nc == 1))
 
     def __getitem__(self, index):
         """Return a data point and its metadata information.
@@ -55,13 +57,8 @@ def __getitem__(self, index):
         A_img = Image.open(A_path).convert('RGB')
         B_img = Image.open(B_path).convert('RGB')
         # apply image transformation
-        A_transform_params = get_params(self.opt, A_img.size)
-        A_transform = get_transform(self.opt, A_transform_params, grayscale=(self.input_nc == 1))
-        A = A_transform(A_img)
-
-        B_transform_params = get_params(self.opt, B_img.size)
-        B_transform = get_transform(self.opt, B_transform_params, grayscale=(self.output_nc == 1))
-        B = B_transform(B_img)
+        A = self.transform_A(A_img)
+        B = self.transform_B(B_img)
 
         return {'A': A, 'B': B, 'A_paths': A_path, 'B_paths': B_path}
 
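Building transform_A and transform_B once in __init__ still yields fresh augmentation on every __getitem__ call, because the randomness now lives inside RandomCrop / RandomHorizontalFlip rather than in per-item get_params. A hedged sketch of that behaviour (options and 'example.jpg' are illustrative):

    from argparse import Namespace
    from PIL import Image
    from data.base_dataset import get_transform

    opt = Namespace(preprocess='resize_and_crop', load_size=286, crop_size=256, no_flip=False)
    transform_A = get_transform(opt)                  # built once, as in UnalignedDataset.__init__
    img = Image.open('example.jpg').convert('RGB')    # hypothetical input image

    a1 = transform_A(img)                             # two calls on the same image
    a2 = transform_A(img)
    print(bool((a1 != a2).any()))                     # usually True: different random crop / flip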

models/base_model.py

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@ def __init__(self, opt):
         self.visual_names = []
         self.optimizers = []
         self.image_paths = []
-        self.metric = 0 # used for learning rate policy 'plateau'
+        self.metric = 0  # used for learning rate policy 'plateau'
 
     @staticmethod
     def modify_commandline_options(parser, is_train):
