|
39 | 39 | "import torch.utils.data as data\n", |
40 | 40 | "import torchvision.datasets as datasets\n", |
41 | 41 | "import torchvision.models as models\n", |
| 42 | + "from torch.utils.data import Dataset, DataLoader\n", |
42 | 43 | "from PIL import Image\n", |
43 | | - "import PIL\n" |
| 44 | + "import PIL\n", |
| 45 | + "import glob" |
44 | 46 | ] |
45 | 47 | }, |
46 | 48 | { |
|
51 | 53 | "source": [ |
52 | 54 | "gpu = 0\n", |
53 | 55 | "batch_size = 32\n", |
54 | | - "imsz = 64\n", |
55 | | - "max_epoch = 500\n", |
56 | | - "\n", |
57 | | - "data_dir = '../GANdataset/victorian/'\n", |
58 | | - "resized_dir = data_dir + 'resized/'\n", |
59 | | - "test_dir = data_dir + 'test/'" |
| 56 | + "max_epoch = 10" |
60 | 57 | ] |
61 | 58 | }, |
62 | 59 | { |
|
65 | 62 | "metadata": {}, |
66 | 63 | "outputs": [], |
67 | 64 | "source": [ |
68 | | - "# class victorianDataset(object):\n", |
69 | | - "# def __init__(self, path):\n", |
70 | | - "# self.path = path\n", |
71 | | - "# self.imgs = list(sorted(os.listdir(self.path)))\n", |
| 65 | + "class victorianDataset(Dataset):\n", |
| 66 | + " def __init__(self, root, transforms_=None):\n", |
| 67 | + " self.transform = transforms.Compose(transforms_ or [transforms.ToTensor()])  # fall back to ToTensor if no transform list is given\n", |
72 | 68 | "\n", |
| 69 | + " self.gray_files = sorted(glob.glob(os.path.join(root, 'gray') + \"/*.*\"))\n", |
| 70 | + " self.color_files = sorted(glob.glob(os.path.join(root, 'resized') + \"/*.*\"))\n", |
| 71 | + " \n", |
| 72 | + " def __getitem__(self, index):\n", |
73 | 73 | "\n", |
74 | | - "# def __getitem__(self, idx):\n", |
75 | | - "# file_image = self.imgs[idx]\n", |
76 | | - "# img_path = os.path.join(self.path, file_image)\n", |
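| | + " # modulo indexing tolerates unequal folder sizes; convert('RGB') keeps three channels\n", |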
| 74 | + " gray_img = Image.open(self.gray_files[index % len(self.gray_files)]).convert(\"RGB\")\n", |
| 75 | + " color_img = Image.open(self.color_files[index % len(self.color_files)]).convert(\"RGB\")\n", |
| 76 | + " \n", |
| 77 | + " gray_img = self.transform(gray_img)\n", |
| 78 | + " color_img = self.transform(color_img)\n", |
77 | 79 | "\n", |
78 | | - "# mean = np.array([0.485, 0.456, 0.406])\n", |
79 | | - "# std = np.array([0.229, 0.224, 0.225])\n", |
| 80 | + " return {\"A\": gray_img, \"B\": color_img}\n", |
80 | 81 | "\n", |
81 | | - "# image = Image.open(img_path).convert(\"RGB\")\n", |
82 | | - "# image = std * image + mean\n", |
83 | | - "# input_gray = image\n", |
84 | | - "# input_gray = np.dot(input_gray[...,:3], [0.299, 0.587, 0.114])\n", |
| 82 | + " def __len__(self):\n", |
| 83 | + " return len(self.gray_files)\n" |
| 84 | + ] |
| 85 | + }, |
| 86 | + { |
| 87 | + "cell_type": "code", |
| 88 | + "execution_count": null, |
| 89 | + "metadata": {}, |
| 90 | + "outputs": [], |
| 91 | + "source": [ |
| 92 | + "root = '../GANdataset/victorian/'\n", |
| 93 | + "img_height = 256  # only used by the Resize transform, which is currently commented out\n", |
| 94 | + "img_width = 256" |
| 95 | + ] |
| 96 | + }, |
| 97 | + { |
| 98 | + "cell_type": "code", |
| 99 | + "execution_count": null, |
| 100 | + "metadata": {}, |
| 101 | + "outputs": [], |
| 102 | + "source": [ |
| 103 | + "transforms_ = [\n", |
| 104 | + " # transforms.Resize((img_height, img_width), Image.BICUBIC),\n", |
| 105 | + " transforms.ToTensor(),\n", |
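| | + " # ToTensor converts PIL images to float tensors scaled to [0, 1]\n", |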
| 106 | + " # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n", |
| 107 | + "]\n", |
85 | 108 | "\n", |
86 | | - "# return torch.FloatTensor(input_gray.transpose((0,1))), torch.FloatTensor(image.transpose((2,0,1)))\n", |
87 | 109 | "\n", |
88 | | - "# def __len__(self):\n", |
89 | | - "# return len(self.imgs)\n" |
| 110 | + "train_loader = DataLoader(\n", |
| 111 | + " VictorianDataset(root, transforms_=transforms_),\n", |
| 112 | + " batch_size=batch_size,\n", |
| 113 | + " shuffle=True\n", |
| 114 | + ")" |
90 | 115 | ] |
91 | 116 | }, |
92 | 117 | { |
|
95 | 120 | "metadata": {}, |
96 | 121 | "outputs": [], |
97 | 122 | "source": [ |
98 | | - "train_dataset = victorianDataset(resized_dir)\n", |
99 | | - "train_loader = data.DataLoader(train_dataset,\n", |
100 | | - " batch_size=batch_size,\n", |
101 | | - " shuffle=True,\n", |
102 | | - " # num_workers=4,\n", |
103 | | - " pin_memory=True)" |
| 123 | + "fig = plt.figure(figsize=(10, 5))\n", |
| 124 | + "rows = 1 \n", |
| 125 | + "cols = 2\n", |
| 126 | + "\n", |
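| | + "# preview one training pair: grayscale input (A) next to its color target (B)\n", |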
| 127 | + "for X in train_loader:\n", |
| 128 | + " \n", |
| 129 | + " ax1 = fig.add_subplot(rows, cols, 1)\n", |
| 130 | + " ax1.imshow(np.clip(np.transpose(X[\"A\"][0], (1,2,0)), 0, 1))\n", |
| 131 | + " ax1.set_title('gray img')\n", |
| 132 | + "\n", |
| 133 | + " ax2 = fig.add_subplot(rows, cols, 2)\n", |
| 134 | + " ax2.imshow(np.clip(np.transpose(X[\"B\"][0], (1,2,0)), 0, 1))\n", |
| 135 | + " ax2.set_title('color img') \n", |
| 136 | + "\n", |
| 137 | + " plt.show()\n", |
| 138 | + " break" |
104 | 139 | ] |
105 | 140 | }, |
106 | 141 | { |
|
109 | 144 | "metadata": {}, |
110 | 145 | "outputs": [], |
111 | 146 | "source": [ |
112 | | - "test_dataset = victorianDataset(test_dir)\n", |
113 | | - "test_loader = data.DataLoader(test_dataset,\n", |
114 | | - " batch_size=batch_size,\n", |
115 | | - " shuffle=True,\n", |
116 | | - " # num_workers=4,\n", |
117 | | - " pin_memory=True)" |
| 147 | + "test_root = root + 'test/'\n", |
| 148 | + "test_batch_size = 6\n", |
| 149 | + "\n", |
| 150 | + "test_loader = DataLoader(\n", |
| 151 | + " VictorianDataset(test_root, transforms_=transforms_),\n", |
| 152 | + " batch_size=test_batch_size,\n", |
| 153 | + " shuffle=True\n", |
| 154 | + ")" |
118 | 155 | ] |
119 | 156 | }, |
120 | 157 | { |
121 | 158 | "cell_type": "code", |
122 | 159 | "execution_count": null, |
123 | 160 | "metadata": {}, |
124 | 161 | "outputs": [], |
125 | | - "source": [] |
| 162 | + "source": [ |
| 163 | + "fig = plt.figure(figsize=(10, 5))\n", |
| 164 | + "rows = 1 \n", |
| 165 | + "cols = 2\n", |
| 166 | + "\n", |
| 167 | + "for X in test_loader:\n", |
| 168 | + " \n", |
| 169 | + " ax1 = fig.add_subplot(rows, cols, 1)\n", |
| 170 | + " ax1.imshow(np.clip(np.transpose(X[\"A\"][0], (1,2,0)), 0, 1))\n", |
| 171 | + " ax1.set_title('gray img')\n", |
| 172 | + "\n", |
| 173 | + " ax2 = fig.add_subplot(rows, cols, 2)\n", |
| 174 | + " ax2.imshow(np.clip(np.transpose(X[\"B\"][0], (1,2,0)), 0, 1))\n", |
| 175 | + " ax2.set_title('color img') \n", |
| 176 | + "\n", |
| 177 | + " plt.show()\n", |
| 178 | + " break" |
| 179 | + ] |
126 | 180 | }, |
127 | 181 | { |
128 | 182 | "cell_type": "code", |
|
134 | 188 | " inp = inp.numpy().transpose((1,2,0))\n", |
135 | 189 | " print(inp.shape)\n", |
136 | 190 | " inp = np.clip(inp, 0, 1) \n", |
137 | | - " plt.imshow(inp)\n", |
138 | | - "\n", |
139 | | - "def gray_imshow(inp):\n", |
140 | | - " inp = inp.numpy()#.transpose((1,2,0))\n", |
141 | | - " print(inp.shape)\n", |
142 | | - " plt.imshow(inp,cmap = plt.get_cmap('gray'))\n" |
| 191 | + " plt.imshow(inp)" |
143 | 192 | ] |
144 | 193 | }, |
145 | 194 | { |
|
383 | 432 | "for epoch in range(max_epoch):\n", |
384 | 433 | " loss_D = 0.0\n", |
385 | 434 | " for i, data in enumerate(train_loader):\n", |
386 | | - " gray, color = data\n", |
387 | | - " #print(len(data[0]))\n", |
388 | | - " b_size = len(data[0])\n", |
| 435 | + " gray, color = data['A'], data['B']\n", |
| 436 | + " # print(len(data['A']))\n", |
| 437 | + " b_size = len(data['A'])\n", |
389 | 438 | "\n", |
390 | | - " color = torch.from_numpy(np.resize(color.numpy(), (b_size, 3, 64, 64))) ### because normalize was not applied above\n", |
| 439 | + " color = torch.from_numpy(np.resize(color.numpy(), (b_size, 3, 64, 64))) \n", |
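| | + " # NB: np.resize above only reshapes/tiles the array; the images are assumed to be 64x64 already\n", |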
391 | 440 | " # gray >> grays (batch_size * 1 * 64 * 64)\n", |
392 | 441 | " grays = torch.from_numpy(np.resize(gray.numpy(), (b_size, 1, 64, 64)))\n", |
393 | 442 | " \n", |
|
441 | 490 | " #print(fake.shape)\n", |
442 | 491 | " fake_img = torchvision.utils.make_grid(fake_img.data)\n", |
443 | 492 | "\n", |
444 | | - "\n", |
445 | | - " if (epoch + 1) % 50 == 0:\n", |
| 493 | + " if (epoch + 1) % 2 == 0:\n", |
446 | 494 | " print('[%d, %5d] real loss: %.4f, fake_loss : %.4f, g_loss : %.4f' % (epoch + 1, i+1, real_loss.item(),fake_loss.item(), g_loss.item()))\n", |
447 | 495 | " imshow(fake_img.cpu())\n", |
448 | 496 | " plt.show()" |
|
457 | 505 | "Discri.eval()\n", |
458 | 506 | "Gener.eval()\n", |
459 | 507 | "\n", |
460 | | - "fixed_noise = torch.randn(batch_size, 1, 64, 64).uniform_(0,1)\n", |
| 508 | + "fixed_noise = torch.rand(test_batch_size, 1, 64, 64)  # uniform noise in [0, 1)\n", |
461 | 509 | "\n", |
462 | 510 | "for i, data in enumerate(test_loader,0) :\n", |
463 | | - " images, label = data\n", |
464 | | - " \n", |
465 | | - " if len(data[0]) != batch_size:\n", |
| 511 | + " images, label = data['A'], data['B']\n", |
| 512 | + "\n", |
| 513 | + " if len(data['A']) != test_batch_size:\n", |
466 | 514 | " continue\n", |
467 | | - " \n", |
468 | | - " grays = torch.from_numpy(np.resize(images.numpy(), (batch_size, 1, 64, 64)))\n", |
| 515 | + "\n", |
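| | + " # pure reshape again; the test images are assumed to be 64x64 already\n", |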
| 516 | + " grays = torch.from_numpy(np.resize(images.numpy(), (test_batch_size, 1, 64, 64)))\n", |
| 517 | + " print(grays.shape)\n", |
469 | 518 | " \n", |
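| | + " # concatenate noise as a second channel: the generator input becomes (N, 2, 64, 64)\n", |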
470 | 519 | " gray = to_variable(torch.cat([grays,fixed_noise],dim = 1))\n", |
471 | 520 | " \n", |
472 | | - " output = Gener(gray)\n", |
| 521 | + " output = Gener(gray)  # run the generator; 'output' is used by make_grid below\n", |
473 | 522 | " inputs = torchvision.utils.make_grid(grays)\n", |
474 | 523 | " labels = torchvision.utils.make_grid(label)\n", |
475 | 524 | " out = torchvision.utils.make_grid(output.data)\n", |
|