|
6 | 6 |
|
7 | 7 |
|
8 | 8 | class UnalignedDataset(BaseDataset):
|
9 |
| - @staticmethod |
10 |
| - def modify_commandline_options(parser, is_train): |
11 |
| - return parser |
| 9 | + """ |
| 10 | + This dataset class can load unaligned/unpaired datasets. |
| 11 | +
|
| 12 | + It requires two directories to host training images from domain A '/path/to/data/trainA' |
| 13 | + and from domain B '/path/to/data/trainB' respectively. |
| 14 | + You can train the model with the dataset flag '--dataroot /path/to/data'. |
| 15 | + Similarly, you need to prepare two directories: |
| 16 | + '/path/to/data/testA' and '/path/to/data/testB' during test time. |
| 17 | + """ |
12 | 18 |
|
13 | 19 | def __init__(self, opt):
|
| 20 | + """Initialize this dataset class. |
| 21 | +
|
| 22 | + Parameters: |
| 23 | + opt -- options (needs to be a subclass of BaseOptions) |
| 24 | + """ |
14 | 25 | BaseDataset.__init__(self, opt)
|
15 |
| - self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A') |
16 |
| - self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B') |
| 26 | + self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A') # create a path '/path/to/data/trainA' |
| 27 | + self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B') # create a path '/path/to/data/trainB' |
17 | 28 |
|
18 |
| - self.A_paths = sorted(make_dataset(self.dir_A, opt.max_dataset_size)) |
19 |
| - self.B_paths = sorted(make_dataset(self.dir_B, opt.max_dataset_size)) |
20 |
| - self.A_size = len(self.A_paths) |
21 |
| - self.B_size = len(self.B_paths) |
| 29 | + self.A_paths = sorted(make_dataset(self.dir_A, opt.max_dataset_size)) # load images from '/path/to/data/trainA' |
| 30 | + self.B_paths = sorted(make_dataset(self.dir_B, opt.max_dataset_size)) # load images from '/path/to/data/trainB' |
| 31 | + self.A_size = len(self.A_paths) # get the size of dataset A |
| 32 | + self.B_size = len(self.B_paths) # get the size of dataset B |
22 | 33 | btoA = self.opt.direction == 'BtoA'
|
23 |
| - input_nc = self.opt.output_nc if btoA else self.opt.input_nc |
24 |
| - output_nc = self.opt.input_nc if btoA else self.opt.output_nc |
25 |
| - self.transform_A = get_transform(opt, input_nc == 1) |
26 |
| - self.transform_B = get_transform(opt, output_nc == 1) |
| 34 | + input_nc = self.opt.output_nc if btoA else self.opt.input_nc # get the number of channels of input image |
| 35 | + output_nc = self.opt.input_nc if btoA else self.opt.output_nc # get the number of channels of output image |
| 36 | + self.transform_A = get_transform(opt, grayscale=(input_nc == 1)) # if nc == 1, we convert RGB to grayscale image |
| 37 | + self.transform_B = get_transform(opt, grayscale=(output_nc == 1)) # if nc == 1, we convert RGB to grayscale image |
27 | 38 |
|
28 | 39 | def __getitem__(self, index):
|
29 |
| - A_path = self.A_paths[index % self.A_size] |
30 |
| - if self.opt.serial_batches: |
| 40 | + """Return a data point and its metadata information. |
| 41 | +
|
| 42 | + Parameters: |
| 43 | + index -- a random integer for data indexing |
| 44 | +
|
| 45 | + Returns a dictionary that contains A, B, A_paths and B_paths |
| 46 | + A (tensor) -- an image in the input domain |
| 47 | + B (tensor) -- an image in the target domain (unpaired with A) |
| 48 | + A_paths (str) -- path of image A |
| 49 | + B_paths (str) -- path of image B |
| 50 | + """ |
| 51 | + A_path = self.A_paths[index % self.A_size] # make sure index is within the range |
| 52 | + if self.opt.serial_batches: # serial mode: pick B by the same index, wrapped to stay within range |
31 | 53 | index_B = index % self.B_size
|
32 |
| - else: |
| 54 | + else: # randomize the index for domain B to avoid fixed pairs. |
33 | 55 | index_B = random.randint(0, self.B_size - 1)
|
34 | 56 | B_path = self.B_paths[index_B]
|
35 | 57 | A_img = Image.open(A_path).convert('RGB')
|
36 | 58 | B_img = Image.open(B_path).convert('RGB')
|
37 |
| - |
| 59 | + # apply image transformation |
38 | 60 | A = self.transform_A(A_img)
|
39 | 61 | B = self.transform_B(B_img)
|
40 | 62 | return {'A': A, 'B': B, 'A_paths': A_path, 'B_paths': B_path}
|
41 | 63 |
|
42 | 64 | def __len__(self):
|
| 65 | + """Return the total number of images in the dataset. |
| 66 | +
|
| 67 | + As we have two datasets with potentially different numbers of images, |
| 68 | + we take the maximum of the two dataset sizes. |
| 69 | + """ |
43 | 70 | return max(self.A_size, self.B_size)
|
0 commit comments