
Commit cbe0857

improve dataset class: max_dataset_size and aligned_dataset
1 parent 8a76751 commit cbe0857

9 files changed: +76 -57 lines

data/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,3 +1,4 @@
+"""This package """
 import importlib
 import torch.utils.data
 from data.base_dataset import BaseDataset

data/aligned_dataset.py

Lines changed: 41 additions & 19 deletions
@@ -1,45 +1,67 @@
 import os.path
 import random
-from data.base_dataset import BaseDataset, get_simple_transform
-import torchvision.transforms.functional as TF
+from data.base_dataset import BaseDataset, get_transform
 import torchvision.transforms as transforms
 from data.image_folder import make_dataset
 from PIL import Image
 
 
 class AlignedDataset(BaseDataset):
-    @staticmethod
-    def modify_commandline_options(parser, is_train):
-        return parser
+    """A dataset class for paired image dataset.
+
+    It assumes that the directory '/path/to/data/train' contains image pairs in the form of {A,B}.
+    During test time, you need to prepare a directory /path/to/data/test.
+    """
 
     def __init__(self, opt):
+        """Initialize this dataset class."""
         BaseDataset.__init__(self, opt)
-        self.dir_AB = os.path.join(opt.dataroot, opt.phase)
-        self.AB_paths = sorted(make_dataset(self.dir_AB))
-        assert(opt.resize_or_crop == 'resize_and_crop')  # only support this mode
-        assert(self.opt.load_size >= self.opt.crop_size)
+        self.dir_AB = os.path.join(opt.dataroot, opt.phase)  # get the image directory
+        self.AB_paths = sorted(make_dataset(self.dir_AB, opt.max_dataset_size))  # get image paths
+        assert(opt.resize_or_crop == 'resize_and_crop')  # only support this mode
+        assert(self.opt.load_size >= self.opt.crop_size)  # crop_size should be smaller than the size of loaded image
         input_nc = self.opt.output_nc if self.opt.direction == 'BtoA' else self.opt.input_nc
         output_nc = self.opt.input_nc if self.opt.direction == 'BtoA' else self.opt.output_nc
-        self.transform_A = get_simple_transform(grayscale=(input_nc == 1))
-        self.transform_B = get_simple_transform(grayscale=(output_nc == 1))
+        # we manually crop and flip in __getitem__ to make sure we apply the same crop and flip for image A and B
+        # we disable the cropping and flipping in the function get_transform
+        self.transform_A = get_transform(opt, grayscale=(input_nc == 1), crop=False, flip=False)
+        self.transform_B = get_transform(opt, grayscale=(output_nc == 1), crop=False, flip=False)
 
     def __getitem__(self, index):
+        """Return a data point and its metadata information.
+
+        Parameters:
+            index -- a random integer for data indexing
+
+        Returns a dictionary of A, B, A_paths and B_paths
+            A (tensor)    -- an image in the input domain
+            B (tensor)    -- its corresponding image in the target domain
+            A_paths (str) -- image paths
+            B_paths (str) -- image paths
+        """
+        # read an image given a random integer index
         AB_path = self.AB_paths[index]
        AB = Image.open(AB_path).convert('RGB')
+        # split AB image into A and B
         w, h = AB.size
         w2 = int(w / 2)
-        A0 = AB.crop((0, 0, w2, h)).resize((self.opt.load_size, self.opt.load_size), Image.BICUBIC)
-        B0 = AB.crop((w2, 0, w, h)).resize((self.opt.load_size, self.opt.load_size), Image.BICUBIC)
-        x, y, h, w = transforms.RandomCrop.get_params(A0, output_size=[self.opt.crop_size, self.opt.crop_size])
-        A = TF.crop(A0, x, y, h, w)
-        B = TF.crop(B0, x, y, h, w)
-
+        A = AB.crop((0, 0, w2, h)).resize((self.opt.load_size, self.opt.load_size), Image.BICUBIC)
+        B = AB.crop((w2, 0, w, h)).resize((self.opt.load_size, self.opt.load_size), Image.BICUBIC)
+        # apply the same cropping to both A and B
+        if 'crop' in self.opt.resize_or_crop:
+            x, y, h, w = transforms.RandomCrop.get_params(A, output_size=[self.opt.crop_size, self.opt.crop_size])
+            A = A.crop((x, y, w, h))
+            B = B.crop((x, y, w, h))
+        # apply the same flipping to both A and B
         if (not self.opt.no_flip) and random.random() < 0.5:
-            A = TF.hflip(A)
-            B = TF.hflip(B)
+            A = A.transpose(Image.FLIP_LEFT_RIGHT)
+            B = B.transpose(Image.FLIP_LEFT_RIGHT)
+        # call standard transformation function
         A = self.transform_A(A)
         B = self.transform_B(B)
+        print(AB_path, index)
         return {'A': A, 'B': B, 'A_paths': AB_path, 'B_paths': AB_path}
 
     def __len__(self):
+        """Return the total number of images."""
        return len(self.AB_paths)
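
The key idea above is that the same random crop and the same random flip must be applied to A and B, or the pairs fall out of alignment. A minimal standalone sketch of that pattern (not the repository's exact code; it assumes A and B have equal size and computes the crop box explicitly so both outputs are exactly crop_size x crop_size):

import random
from PIL import Image

def paired_crop_and_flip(A, B, crop_size, flip=True):
    # draw one crop position and reuse it for both images
    w, h = A.size
    left = random.randint(0, w - crop_size)
    upper = random.randint(0, h - crop_size)
    box = (left, upper, left + crop_size, upper + crop_size)  # PIL box: (left, upper, right, lower)
    A, B = A.crop(box), B.crop(box)
    # draw one coin flip and reuse it for both images
    if flip and random.random() < 0.5:
        A = A.transpose(Image.FLIP_LEFT_RIGHT)
        B = B.transpose(Image.FLIP_LEFT_RIGHT)
    return A, B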

data/base_dataset.py

Lines changed: 18 additions & 21 deletions
@@ -22,27 +22,28 @@ def __getitem__(self, index):
         pass
 
 
-def get_transform(opt, grayscale=False, convert=True):
+def get_transform(opt, grayscale=False, convert=True, crop=True, flip=True):
     transform_list = []
     if grayscale:
         transform_list.append(transforms.Grayscale(1))
     if opt.resize_or_crop == 'resize_and_crop':
         osize = [opt.load_size, opt.load_size]
         transform_list.append(transforms.Resize(osize, Image.BICUBIC))
         transform_list.append(transforms.RandomCrop(opt.crop_size))
-    elif opt.resize_or_crop == 'crop':
+    elif opt.resize_or_crop == 'crop' and crop:
         transform_list.append(transforms.RandomCrop(opt.crop_size))
     elif opt.resize_or_crop == 'scale_width':
         transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.crop_size)))
     elif opt.resize_or_crop == 'scale_width_and_crop':
         transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size)))
-        transform_list.append(transforms.RandomCrop(opt.crop_size))
+        if crop:
+            transform_list.append(transforms.RandomCrop(opt.crop_size))
     elif opt.resize_or_crop == 'none':
         transform_list.append(transforms.Lambda(lambda img: __adjust(img)))
     else:
         raise ValueError('--resize_or_crop %s is not a valid option.' % opt.resize_or_crop)
 
-    if not opt.no_flip:
+    if not opt.no_flip and flip:
         transform_list.append(transforms.RandomHorizontalFlip())
     if convert:
         transform_list += [transforms.ToTensor(),
@@ -51,22 +52,14 @@ def get_transform(opt, grayscale=False, convert=True):
     return transforms.Compose(transform_list)
 
 
-def get_simple_transform(grayscale=False):
-    transform_list = []
-    if grayscale:
-        transform_list.append(transforms.Grayscale(1))
-    transform_list += [transforms.ToTensor(),
-                       transforms.Normalize((0.5, 0.5, 0.5),
-                                            (0.5, 0.5, 0.5))]
-    return transforms.Compose(transform_list)
-
-
 def __adjust(img):
-    """Modify the width and height to be multiple of 4"""
+    """Modify the width and height to be multiple of 4
+
+    the size needs to be a multiple of 4,
+    because going through generator network may change img size
+    and eventually cause size mismatch error
+    """
     ow, oh = img.size
-    # the size needs to be a multiple of this number,
-    # because going through generator network may change img size
-    # and eventually cause size mismatch error
     mult = 4
     if ow % mult == 0 and oh % mult == 0:
         return img
@@ -82,11 +75,14 @@ def __adjust(img):
 
 
 def __scale_width(img, target_width):
+    """Resize images so that the output image width is the same as target width
+
+    the size needs to be a multiple of 4,
+    because going through generator network may change img size
+    and eventually cause size mismatch error
+    """
     ow, oh = img.size
 
-    # the size needs to be a multiple of this number,
-    # because going through generator network may change img size
-    # and eventually cause size mismatch error
     mult = 4
     assert target_width % mult == 0, "the target width needs to be multiple of %d." % mult
     if (ow == target_width and oh % mult == 0):
@@ -103,6 +99,7 @@ def __scale_width(img, target_width):
 
 
 def __print_size_warning(ow, oh, w, h):
+    """Print warning information about image size (only print once)"""
     if not hasattr(__print_size_warning, 'has_printed'):
         print("The image size needs to be a multiple of 4. "
               "The loaded image size was (%d, %d), so it was adjusted to "

data/colorization_dataset.py

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@ def modify_commandline_options(parser, is_train):
     def __init__(self, opt):
         BaseDataset.__init__(self, opt)
         self.dir_A = os.path.join(opt.dataroot)
-        self.A_paths = sorted(make_dataset(self.dir_A))
+        self.A_paths = sorted(make_dataset(self.dir_A, opt.max_dataset_size))
         assert(opt.input_nc == 1 and opt.output_nc == 2 and opt.direction == 'AtoB')
         self.transform = get_transform(opt, convert=False)

data/image_folder.py

Lines changed: 2 additions & 3 deletions
@@ -20,7 +20,7 @@ def is_image_file(filename):
     return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
 
 
-def make_dataset(dir):
+def make_dataset(dir, max_dataset_size=float("inf")):
     images = []
     assert os.path.isdir(dir), '%s is not a valid directory' % dir
 
@@ -29,8 +29,7 @@ def make_dataset(dir):
             if is_image_file(fname):
                 path = os.path.join(root, fname)
                 images.append(path)
-
-    return images
+    return images[:min(max_dataset_size, len(images))]
 
 
 def default_loader(path):
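
Existing callers are unaffected because the new parameter defaults to float("inf"); the dataset classes above now pass opt.max_dataset_size through. A short usage sketch (the dataset path is hypothetical):

from data.image_folder import make_dataset

# default: collect every image under the directory
all_paths = sorted(make_dataset('./datasets/facades/train'))

# cap the dataset at 100 images, e.g. for a quick smoke-test run
few_paths = sorted(make_dataset('./datasets/facades/train', max_dataset_size=100))
assert len(few_paths) <= 100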

data/single_dataset.py

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@ def modify_commandline_options(parser, is_train):
 
     def __init__(self, opt):
         BaseDataset.__init__(self, opt)
-        self.A_paths = sorted(make_dataset(opt.dataroot))
+        self.A_paths = sorted(make_dataset(opt.dataroot, opt.max_dataset_size))
         input_nc = self.opt.output_nc if self.opt.direction == 'BtoA' else self.opt.input_nc
         self.transform = get_transform(opt, input_nc == 1)

data/template_dataset.py

Lines changed: 2 additions & 2 deletions
@@ -46,7 +46,7 @@ def __init__(self, opt):
         # save the option and dataset root
         BaseDataset.__init__(self, opt)
         # get the image paths of your dataset;
-        self.image_paths = []  # You can call <sorted(make_dataset(self.root))> to get all the image paths under the directory self.root
+        self.image_paths = []  # You can call <sorted(make_dataset(self.root, opt.max_dataset_size))> to get all the image paths under the directory self.root
         # define the default transform function. You can use <base_dataset.get_transform>; You can also define your custom transform function
         self.transform = get_transform(opt)
 
@@ -57,7 +57,7 @@ def __getitem__(self, index):
             index -- a random integer for data indexing
 
         Returns:
-            a dicrtionary of data with their names. It ususally contains the data itself and its metadata information.
+            a dictionary of data with their names. It usually contains the data itself and its metadata information.
 
         Step 1: get a random image path: e.g., path = self.image_paths[index]
         Step 2: load your data from the disk: e.g., image = Image.open(path).convert('RGB').
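
Putting the template's steps together, a hypothetical minimal dataset might look like the sketch below; the class name, paths, and ToTensor transform are illustrative, not part of the repository:

from PIL import Image
import torchvision.transforms as transforms

class MinimalDataset:
    """A hypothetical dataset following the template's steps."""

    def __init__(self, image_paths):
        self.image_paths = image_paths          # the image paths of your dataset
        self.transform = transforms.ToTensor()  # a stand-in for get_transform(opt)

    def __getitem__(self, index):
        path = self.image_paths[index]            # Step 1: get a random image path
        image = Image.open(path).convert('RGB')   # Step 2: load the data from disk
        data = self.transform(image)              # convert it to a PyTorch tensor
        return {'data': data, 'path': path}       # return a dictionary of data

    def __len__(self):
        return len(self.image_paths)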

data/unaligned_dataset.py

Lines changed: 2 additions & 5 deletions
@@ -15,11 +15,8 @@ def __init__(self, opt):
         self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A')
         self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B')
 
-        self.A_paths = make_dataset(self.dir_A)
-        self.B_paths = make_dataset(self.dir_B)
-
-        self.A_paths = sorted(self.A_paths)
-        self.B_paths = sorted(self.B_paths)
+        self.A_paths = sorted(make_dataset(self.dir_A, opt.max_dataset_size))
+        self.B_paths = sorted(make_dataset(self.dir_B, opt.max_dataset_size))
         self.A_size = len(self.A_paths)
         self.B_size = len(self.B_paths)
         btoA = self.opt.direction == 'BtoA'

util/visualizer.py

Lines changed: 8 additions & 5 deletions
@@ -3,8 +3,8 @@
 import sys
 import ntpath
 import time
-from . import util
-from . import html
+from . import util, html
+from subprocess import Popen, PIPE
 from scipy.misc import imresize
 
 if sys.version_info[0] == 2:
@@ -41,11 +41,12 @@ def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256):
 
 class Visualizer():
     def __init__(self, opt):
+        self.opt = opt
         self.display_id = opt.display_id
         self.use_html = opt.isTrain and not opt.no_html
         self.win_size = opt.display_winsize
         self.name = opt.name
-        self.opt = opt
+        self.port = opt.display_port
         self.saved = False
         if self.display_id > 0:
             import visdom
@@ -66,8 +67,10 @@ def reset(self):
         self.saved = False
 
     def throw_visdom_connection_error(self):
-        print('\n\nCould not connect to Visdom server (https://github.com/facebookresearch/visdom) for displaying training progress.\nYou can suppress connection to Visdom using the option --display_id -1. To install visdom, run \n$ pip install visdom\n, and start the server by \n$ python -m visdom.server.\n\n')
-        exit(1)
+        cmd = sys.executable + ' -m visdom.server -p %d &>/dev/null &' % self.port
+        print('\n\nCould not connect to Visdom server. \n Trying to start a server....')
+        print('Command: %s' % cmd)
+        Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE)
 
     # |visuals|: dictionary of images to display or save
     def display_current_results(self, visuals, epoch, save_result):
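
Rather than exiting when the Visdom server is unreachable, the new handler tries to launch one in the background on the configured display port. A standalone sketch of the same pattern (the default port 8097 is an assumption, matching visdom's usual default):

import sys
from subprocess import Popen, PIPE

def start_visdom_server(port=8097):
    # run the server with the current interpreter so it uses the same environment
    cmd = sys.executable + ' -m visdom.server -p %d &>/dev/null &' % port
    print('Could not connect to Visdom server. Trying to start a server....')
    print('Command: %s' % cmd)
    # fire and forget; the caller can retry its connection afterwards
    Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE)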
