|
1 | 1 | import os.path
|
2 | 2 | import random
|
3 |
| -from data.base_dataset import BaseDataset, get_simple_transform |
4 |
| -import torchvision.transforms.functional as TF |
| 3 | +from data.base_dataset import BaseDataset, get_transform |
5 | 4 | import torchvision.transforms as transforms
|
6 | 5 | from data.image_folder import make_dataset
|
7 | 6 | from PIL import Image
|
8 | 7 |
|
9 | 8 |
|
class AlignedDataset(BaseDataset):
    """A dataset class for paired image datasets.

    It assumes that the directory '/path/to/data/train' contains image pairs in the form of {A,B}:
    each file is split down the middle, the left half is A and the right half is B.
    During test time, you need to prepare a directory '/path/to/data/test'.
    """

    def __init__(self, opt):
        """Initialize this dataset class.

        Parameters:
            opt -- option object; this class reads dataroot, phase, max_dataset_size,
                   resize_or_crop, load_size, crop_size, direction, input_nc and output_nc
                   here, and no_flip in __getitem__

        Raises:
            AssertionError -- if resize_or_crop is not 'resize_and_crop' or load_size < crop_size
        """
        # NOTE(review): self.opt is read below and is presumably set by BaseDataset.__init__ -- confirm
        BaseDataset.__init__(self, opt)
        self.dir_AB = os.path.join(opt.dataroot, opt.phase)  # get the image directory
        self.AB_paths = sorted(make_dataset(self.dir_AB, opt.max_dataset_size))  # get image paths
        assert(opt.resize_or_crop == 'resize_and_crop')  # only support this mode
        assert(self.opt.load_size >= self.opt.crop_size)  # crop_size should not exceed the size of the loaded image
        # the channel counts are swapped when translating from B to A
        input_nc = self.opt.output_nc if self.opt.direction == 'BtoA' else self.opt.input_nc
        output_nc = self.opt.input_nc if self.opt.direction == 'BtoA' else self.opt.output_nc
        # we manually crop and flip in __getitem__ to make sure we apply the same crop and flip for image A and B
        # we disable the cropping and flipping in the function get_transform
        self.transform_A = get_transform(opt, grayscale=(input_nc == 1), crop=False, flip=False)
        self.transform_B = get_transform(opt, grayscale=(output_nc == 1), crop=False, flip=False)

    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index -- an integer for data indexing

        Returns a dictionary of A, B, A_paths and B_paths
            A -- the left half of the image file, after self.transform_A
            B -- the right half of the image file, after self.transform_B
            A_paths (str) -- path of the combined image file
            B_paths (str) -- path of the combined image file (same as A_paths)
        """
        # read an image given an integer index
        AB_path = self.AB_paths[index]
        AB = Image.open(AB_path).convert('RGB')
        # split AB image into A (left half) and B (right half), each resized to load_size x load_size
        full_w, full_h = AB.size
        half_w = int(full_w / 2)
        load_size = (self.opt.load_size, self.opt.load_size)
        A = AB.crop((0, 0, half_w, full_h)).resize(load_size, Image.BICUBIC)
        B = AB.crop((half_w, 0, full_w, full_h)).resize(load_size, Image.BICUBIC)
        # apply the same cropping to both A and B
        if 'crop' in self.opt.resize_or_crop:
            # get_params returns (top, left, height, width), whereas PIL's crop expects
            # a (left, upper, right, lower) box, so the values must be converted
            top, left, crop_h, crop_w = transforms.RandomCrop.get_params(
                A, output_size=[self.opt.crop_size, self.opt.crop_size])
            crop_box = (left, top, left + crop_w, top + crop_h)
            A = A.crop(crop_box)
            B = B.crop(crop_box)
        # apply the same flipping to both A and B
        if (not self.opt.no_flip) and random.random() < 0.5:
            A = A.transpose(Image.FLIP_LEFT_RIGHT)
            B = B.transpose(Image.FLIP_LEFT_RIGHT)
        # call standard transformation function
        A = self.transform_A(A)
        B = self.transform_B(B)
        return {'A': A, 'B': B, 'A_paths': AB_path, 'B_paths': AB_path}

    def __len__(self):
        """Return the total number of images."""
        return len(self.AB_paths)
|
0 commit comments