
Commit 8cc270b

add comments: datasets
1 parent c047c55 commit 8cc270b

File tree

7 files changed: +116 lines, -30 lines


data/aligned_dataset.py

Lines changed: 6 additions & 2 deletions
@@ -14,7 +14,11 @@ class AlignedDataset(BaseDataset):
     """

     def __init__(self, opt):
-        """Initialize this dataset class."""
+        """Initialize this dataset class.
+
+        Parameters:
+            opt -- options (needs to be a subclass of BaseOptions)
+        """
         BaseDataset.__init__(self, opt)
         self.dir_AB = os.path.join(opt.dataroot, opt.phase)  # get the image directory
         self.AB_paths = sorted(make_dataset(self.dir_AB, opt.max_dataset_size))  # get image paths
@@ -32,7 +36,7 @@ def __getitem__(self, index):
         Parameters:
             index - - a random integer for data indexing

-        Returns a dictionary of A, B, A_paths and B_paths
+        Returns a dictionary that contains A, B, A_paths and B_paths
             A(tensor) - - an image in the input domain
             B(tensor) - - its corresponding image in the target domain
             A_paths(str) - - image paths
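
For reference, a minimal usage sketch (not part of the commit) of how a training loop might consume the dictionary documented above. It assumes `dataset` is any torch Dataset whose __getitem__ returns the keys A, B, A_paths and B_paths, such as an AlignedDataset built from the repository's options; the `iterate_pairs` helper is purely illustrative.

from torch.utils.data import DataLoader


def iterate_pairs(dataset, batch_size=1):
    """Loop over paired batches; `dataset` must yield {'A', 'B', 'A_paths', 'B_paths'} dicts."""
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    for batch in loader:
        real_A, real_B = batch['A'], batch['B']                # image tensors from the two domains
        path_A, path_B = batch['A_paths'], batch['B_paths']    # lists of file paths (metadata)
        yield real_A, real_B, path_A, path_B                   # hand the pair to the caller (e.g., a training step)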

data/base_dataset.py

Lines changed: 37 additions & 2 deletions
@@ -1,16 +1,43 @@
+"""This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
+
+It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
+"""
 import torch.utils.data as data
 from PIL import Image
 import torchvision.transforms as transforms
 from abc import ABC, abstractmethod


 class BaseDataset(data.Dataset, ABC):
+    """This class is an abstract base class (ABC) for datasets.
+
+    To create a subclass, you need to implement four functions:
+    -- <__init__> (initialize the class, first call BaseDataset.__init__(self, opt))
+    -- <__len__> (return the size of dataset)
+    -- <__getitem__> (get a data point)
+    -- (optionally) <modify_commandline_options> (add dataset-specific options and set default options).
+    """
+
     def __init__(self, opt):
+        """Initialize the class; save the options in the class
+
+        Parameters:
+            opt -- options (needs to be a subclass of BaseOptions)
+        """
         self.opt = opt
         self.root = opt.dataroot

     @staticmethod
     def modify_commandline_options(parser, is_train):
+        """Add new dataset-specific options, and rewrite default values for existing options.
+
+        Parameters:
+            parser -- original option parser
+            is_train -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
+
+        Returns:
+            the modified parser.
+        """
         return parser

     @abstractmethod
@@ -20,13 +47,21 @@ def __len__(self):

     @abstractmethod
     def __getitem__(self, index):
+        """Return a data point and its metadata information.
+
+        Parameters:
+            index - - a random integer for data indexing
+
+        Returns:
+            a dictionary of data with their names. It usually contains the data itself and its metadata information.
+        """
         pass


 def get_transform(opt, grayscale=False, convert=True, crop=True, flip=True):
     """Create a torchvision transformation function

-    The type of transformation is defined by option (e.g., [preprocess], [load_size], [crop_size])
+    The type of transformation is defined by option(e.g., [preprocess], [load_size], [crop_size])
     and can be overwritten by arguments such as [convert], [crop], and [flip]
     """
     transform_list = []
@@ -105,7 +140,7 @@ def __scale_width(img, target_width):


 def __print_size_warning(ow, oh, w, h):
-    """Print warning information about image size (only print once)"""
+    """Print warning information about image size(only print once)"""
     if not hasattr(__print_size_warning, 'has_printed'):
         print("The image size needs to be a multiple of 4. "
               "The loaded image size was (%d, %d), so it was adjusted to "

data/colorization_dataset.py

Lines changed: 4 additions & 0 deletions
@@ -8,6 +8,10 @@


 class ColorizationDataset(BaseDataset):
+    """This dataset class can load a set of natural images in RGB, and convert RGB format into (L, ab) pairs in Lab color space.
+
+    This dataset is required by pix2pix-based colorization model ('--model colorization')
+    """
     @staticmethod
     def modify_commandline_options(parser, is_train):
         parser.set_defaults(input_nc=1, output_nc=2, direction='AtoB')
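
A minimal sketch of the RGB-to-(L, ab) split mentioned in the docstring, assuming scikit-image, NumPy and PyTorch are available; this is not the repository's implementation, and the [-1, 1]-style scaling factors are illustrative choices.

import numpy as np
import torch
from PIL import Image
from skimage import color


def rgb_to_L_ab(path):
    """Load an RGB image and return (L, ab) tensors in Lab color space."""
    rgb = np.asarray(Image.open(path).convert('RGB'), dtype=np.float64) / 255.0  # H x W x 3 in [0, 1]
    lab = color.rgb2lab(rgb).astype(np.float32)          # L in [0, 100], a/b roughly in [-110, 110]
    lab_t = torch.from_numpy(lab).permute(2, 0, 1)       # 3 x H x W
    L = lab_t[[0], ...] / 50.0 - 1.0                     # 1 channel, scaled to [-1, 1]  -> matches input_nc=1
    ab = lab_t[[1, 2], ...] / 110.0                      # 2 channels, roughly [-1, 1]   -> matches output_nc=2
    return L, ab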

data/image_folder.py

Lines changed: 4 additions & 4 deletions
@@ -1,7 +1,7 @@
-"""Modified Image folder class
-Code from https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
-Modified the original code so that it also loads images from the current
-directory as well as the subdirectories
+"""A modified image folder class
+
+We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
+so that this class can load images from both current directory and its subdirectories.
 """

 import torch.utils.data as data
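
A minimal sketch of the recursive walk such a modified image folder performs, i.e. collecting images from a directory and all of its subdirectories; the helper name and the extension list are assumptions, not the module's actual API.

import os

IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.tif', '.tiff')


def list_images(directory, max_size=float("inf")):
    """Return sorted image paths found in `directory` and every subdirectory below it."""
    assert os.path.isdir(directory), '%s is not a valid directory' % directory
    paths = []
    for root, _, fnames in sorted(os.walk(directory)):    # visits the directory itself and its subdirectories
        for fname in sorted(fnames):
            if fname.lower().endswith(IMG_EXTENSIONS):
                paths.append(os.path.join(root, fname))
    return paths if max_size == float("inf") else paths[:int(max_size)]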

data/single_dataset.py

Lines changed: 19 additions & 3 deletions
@@ -4,21 +4,37 @@


 class SingleDataset(BaseDataset):
-    @staticmethod
-    def modify_commandline_options(parser, is_train):
-        return parser
+    """This dataset class can load a set of images specified by the path --dataroot /path/to/data.
+
+    It can be used for generating CycleGAN results only for one side with the model option '--model test'.
+    """

     def __init__(self, opt):
+        """Initialize this dataset class.
+
+        Parameters:
+            opt -- options (needs to be a subclass of BaseOptions)
+        """
         BaseDataset.__init__(self, opt)
         self.A_paths = sorted(make_dataset(opt.dataroot, opt.max_dataset_size))
         input_nc = self.opt.output_nc if self.opt.direction == 'BtoA' else self.opt.input_nc
         self.transform = get_transform(opt, input_nc == 1)

     def __getitem__(self, index):
+        """Return a data point and its metadata information.
+
+        Parameters:
+            index - - a random integer for data indexing
+
+        Returns a dictionary that contains A and A_paths
+            A(tensor) - - an image in one domain
+            A_paths(str) - - the path of the image
+        """
         A_path = self.A_paths[index]
         A_img = Image.open(A_path).convert('RGB')
         A = self.transform(A_img)
         return {'A': A, 'A_paths': A_path}

     def __len__(self):
+        """Return the total number of images in the dataset."""
         return len(self.A_paths)
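
SingleDataset switches get_transform to grayscale when the input has a single channel (input_nc == 1). Below is a minimal torchvision sketch of such a pipeline; the function name, parameter names and default sizes are illustrative assumptions, not the repository's option names.

import torchvision.transforms as transforms


def build_transform(grayscale=False, load_size=286, crop_size=256):
    """Resize, crop and normalize; prepend a Grayscale step for 1-channel inputs."""
    tlist = []
    if grayscale:
        tlist.append(transforms.Grayscale(1))             # collapse RGB to a single channel
    tlist += [
        transforms.Resize(load_size, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.RandomCrop(crop_size),
        transforms.ToTensor(),
    ]
    # normalize to [-1, 1]; mean/std length must match the channel count
    mean = std = (0.5,) if grayscale else (0.5, 0.5, 0.5)
    tlist.append(transforms.Normalize(mean, std))
    return transforms.Compose(tlist)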

data/template_dataset.py

Lines changed: 2 additions & 2 deletions
@@ -23,8 +23,8 @@ def modify_commandline_options(parser, is_train):
         """Add new dataset-specific options, and rewrite default values for existing options.

         Parameters:
-            parser -- the option parser
-            is_train -- if it is training phase or test phase. You can use this flag to add training-specific or test-specific options.
+            parser -- original option parser
+            is_train -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.

         Returns:
             the modified parser.

data/unaligned_dataset.py

Lines changed: 44 additions & 17 deletions
@@ -6,38 +6,65 @@


 class UnalignedDataset(BaseDataset):
-    @staticmethod
-    def modify_commandline_options(parser, is_train):
-        return parser
+    """
+    This dataset class can load unaligned/unpaired datasets.
+
+    It requires two directories to host training images from domain A '/path/to/data/trainA'
+    and from domain B '/path/to/data/trainB' respectively.
+    You can train the model with the dataset flag '--dataroot /path/to/data'.
+    Similarly, you need to prepare two directories:
+    '/path/to/data/testA' and '/path/to/data/testB' during test time.
+    """

     def __init__(self, opt):
+        """Initialize this dataset class.
+
+        Parameters:
+            opt -- options (needs to be a subclass of BaseOptions)
+        """
         BaseDataset.__init__(self, opt)
-        self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A')
-        self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B')
+        self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A')  # create a path '/path/to/data/trainA'
+        self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B')  # create a path '/path/to/data/trainB'

-        self.A_paths = sorted(make_dataset(self.dir_A, opt.max_dataset_size))
-        self.B_paths = sorted(make_dataset(self.dir_B, opt.max_dataset_size))
-        self.A_size = len(self.A_paths)
-        self.B_size = len(self.B_paths)
+        self.A_paths = sorted(make_dataset(self.dir_A, opt.max_dataset_size))  # load images from '/path/to/data/trainA'
+        self.B_paths = sorted(make_dataset(self.dir_B, opt.max_dataset_size))  # load images from '/path/to/data/trainB'
+        self.A_size = len(self.A_paths)  # get the size of dataset A
+        self.B_size = len(self.B_paths)  # get the size of dataset B
         btoA = self.opt.direction == 'BtoA'
-        input_nc = self.opt.output_nc if btoA else self.opt.input_nc
-        output_nc = self.opt.input_nc if btoA else self.opt.output_nc
-        self.transform_A = get_transform(opt, input_nc == 1)
-        self.transform_B = get_transform(opt, output_nc == 1)
+        input_nc = self.opt.output_nc if btoA else self.opt.input_nc    # get the number of channels of input image
+        output_nc = self.opt.input_nc if btoA else self.opt.output_nc   # get the number of channels of output image
+        self.transform_A = get_transform(opt, grayscale=(input_nc == 1))   # if nc == 1, we convert RGB to grayscale image
+        self.transform_B = get_transform(opt, grayscale=(output_nc == 1))  # if nc == 1, we convert RGB to grayscale image

     def __getitem__(self, index):
-        A_path = self.A_paths[index % self.A_size]
-        if self.opt.serial_batches:
+        """Return a data point and its metadata information.
+
+        Parameters:
+            index - - a random integer for data indexing
+
+        Returns a dictionary that contains A, B, A_paths and B_paths
+            A(tensor) - - an image in the input domain
+            B(tensor) - - its corresponding image in the target domain
+            A_paths(str) - - image paths
+            B_paths(str) - - image paths
+        """
+        A_path = self.A_paths[index % self.A_size]  # make sure index is within the range
+        if self.opt.serial_batches:   # make sure index is within the range
             index_B = index % self.B_size
-        else:
+        else:   # randomize the index for domain B to avoid fixed pairs.
             index_B = random.randint(0, self.B_size - 1)
         B_path = self.B_paths[index_B]
         A_img = Image.open(A_path).convert('RGB')
         B_img = Image.open(B_path).convert('RGB')
-
+        # apply image transformation
         A = self.transform_A(A_img)
         B = self.transform_B(B_img)
         return {'A': A, 'B': B, 'A_paths': A_path, 'B_paths': B_path}

     def __len__(self):
+        """Return the total number of images in the dataset.
+
+        As we have two datasets with potentially different number of images,
+        we take a maximum of the two.
+        """
         return max(self.A_size, self.B_size)
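
A standalone sketch of the unpaired indexing used above (the pair_indices helper is illustrative): domain A is indexed modulo its size, domain B is either walked serially or sampled at random, and __len__ = max(A_size, B_size) lets every image of the larger domain be visited once per epoch while the smaller domain wraps around.

import random


def pair_indices(index, A_size, B_size, serial_batches=False):
    """Map one dataset index to an (index_A, index_B) pair."""
    index_A = index % A_size                   # keep the A index within range
    if serial_batches:                         # fixed pairing: walk B in order
        index_B = index % B_size
    else:                                      # unpaired setting: random B avoids fixed pairs
        index_B = random.randint(0, B_size - 1)
    return index_A, index_B


# e.g. with A_size=4 and B_size=6, an epoch of length max(4, 6) = 6 revisits A images
for i in range(6):
    print(pair_indices(i, A_size=4, B_size=6, serial_batches=True))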
