
Commit 9341b95

change victorianDataset for ch3
1 parent b4a55bb commit 9341b95

File tree

1 file changed: +104 -55 lines changed

book/chapters/GAN/Ch3-GAN.ipynb

Lines changed: 104 additions & 55 deletions
@@ -39,8 +39,10 @@
     "import torch.utils.data as data\n",
     "import torchvision.datasets as datasets\n",
     "import torchvision.models as models\n",
+    "from torch.utils.data import Dataset, DataLoader\n",
     "from PIL import Image\n",
-    "import PIL\n"
+    "import PIL\n",
+    "import glob"
    ]
   },
   {
@@ -51,12 +53,7 @@
    "source": [
     "gpu = 0\n",
     "batch_size = 32\n",
-    "imsz = 64\n",
-    "max_epoch = 500\n",
-    "\n",
-    "data_dir = '../GANdataset/victorian/'\n",
-    "resized_dir = data_dir + 'resized/'\n",
-    "test_dir = data_dir + 'test/'"
+    "max_epoch = 10"
    ]
   },
   {
@@ -65,28 +62,56 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# class victorianDataset(object):\n",
-    "#     def __init__(self, path):\n",
-    "#         self.path = path\n",
-    "#         self.imgs = list(sorted(os.listdir(self.path)))\n",
+    "class victorianDataset(Dataset):\n",
+    "    def __init__(self, root, transforms_=None):\n",
+    "        self.transform = transforms.Compose(transforms_)\n",
     "\n",
+    "        self.gray_files = sorted(glob.glob(os.path.join(root, 'gray') + \"/*.*\"))\n",
+    "        self.color_files = sorted(glob.glob(os.path.join(root, 'resized') + \"/*.*\"))\n",
+    "        \n",
+    "    def __getitem__(self, index):\n",
     "\n",
-    "#     def __getitem__(self, idx):\n",
-    "#         file_image = self.imgs[idx]\n",
-    "#         img_path = os.path.join(self.path, file_image)\n",
+    "        gray_img = Image.open(self.gray_files[index % len(self.gray_files)]).convert(\"RGB\")\n",
+    "        color_img = Image.open(self.color_files[index % len(self.color_files)]).convert(\"RGB\")\n",
+    "        \n",
+    "        gray_img = self.transform(gray_img)\n",
+    "        color_img = self.transform(color_img)\n",
     "\n",
-    "#         mean = np.array([0.485, 0.456, 0.406])\n",
-    "#         std = np.array([0.229, 0.224, 0.225])\n",
+    "        return {\"A\": gray_img, \"B\": color_img}\n",
     "\n",
-    "#         image = Image.open(img_path).convert(\"RGB\")\n",
-    "#         image = std * image + mean\n",
-    "#         input_gray = image\n",
-    "#         input_gray = np.dot(input_gray[...,:3], [0.299, 0.587, 0.114])\n",
+    "    def __len__(self):\n",
+    "        return len(self.gray_files)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "root = '../GANdataset/victorian/'\n",
+    "img_height = 256\n",
+    "img_width = 256"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "transforms_ = [\n",
+    "    # transforms.Resize((img_height, img_width), Image.BICUBIC),\n",
+    "    transforms.ToTensor(),\n",
+    "    # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
+    "]\n",
     "\n",
-    "#         return torch.FloatTensor(input_gray.transpose((0,1))), torch.FloatTensor(image.transpose((2,0,1)))\n",
     "\n",
-    "#     def __len__(self):\n",
-    "#         return len(self.imgs)\n"
+    "train_loader = DataLoader(\n",
+    "    victorianDataset(root, transforms_=transforms_),\n",
+    "    batch_size=batch_size,\n",
+    "    shuffle=True\n",
+    ")"
    ]
   },
   {
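For orientation, here is a minimal standalone sketch of the paired-dataset pattern this hunk introduces: two sorted file listings, index-aligned so each grayscale image matches its color counterpart, returned as an {"A", "B"} dict. The "toy_data/" root and the Resize step are illustrative assumptions, not part of the commit (the commit leaves Resize commented out, presumably because the images are pre-resized).

# Sketch of the paired gray/color Dataset pattern from the hunk above.
# "toy_data/" and the Resize step are illustrative assumptions only.
import glob
import os

from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms


class PairedImageDataset(Dataset):
    def __init__(self, root, transforms_=None):
        self.transform = transforms.Compose(transforms_)
        # Sorting both listings keeps index i in gray/ aligned with index i in resized/.
        self.gray_files = sorted(glob.glob(os.path.join(root, "gray", "*.*")))
        self.color_files = sorted(glob.glob(os.path.join(root, "resized", "*.*")))

    def __getitem__(self, index):
        # Modulo indexing tolerates folders of unequal length.
        gray = Image.open(self.gray_files[index % len(self.gray_files)]).convert("RGB")
        color = Image.open(self.color_files[index % len(self.color_files)]).convert("RGB")
        return {"A": self.transform(gray), "B": self.transform(color)}

    def __len__(self):
        return len(self.gray_files)


transforms_ = [transforms.Resize((256, 256)), transforms.ToTensor()]
loader = DataLoader(PairedImageDataset("toy_data/", transforms_), batch_size=4, shuffle=True)
batch = next(iter(loader))
print(batch["A"].shape, batch["B"].shape)  # torch.Size([4, 3, 256, 256]) each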
@@ -95,12 +120,22 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "train_dataset = victorianDataset(resized_dir)\n",
-    "train_loader = data.DataLoader(train_dataset,\n",
-    "                               batch_size=batch_size,\n",
-    "                               shuffle=True,\n",
-    "                               # num_workers=4,\n",
-    "                               pin_memory=True)"
+    "fig = plt.figure(figsize=(10, 5))\n",
+    "rows = 1 \n",
+    "cols = 2\n",
+    "\n",
+    "for X in train_loader:\n",
+    "    \n",
+    "    ax1 = fig.add_subplot(rows, cols, 1)\n",
+    "    ax1.imshow(np.clip(np.transpose(X[\"A\"][0], (1,2,0)), 0, 1))\n",
+    "    ax1.set_title('gray img')\n",
+    "\n",
+    "    ax2 = fig.add_subplot(rows, cols, 2)\n",
+    "    ax2.imshow(np.clip(np.transpose(X[\"B\"][0], (1,2,0)), 0, 1))\n",
+    "    ax2.set_title('color img') \n",
+    "\n",
+    "    plt.show()\n",
+    "    break"
    ]
   },
   {
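A note on the preview cell above: ToTensor() produces channel-first (C, H, W) tensors with values in [0, 1], while matplotlib's imshow expects channel-last (H, W, C), which is what the (1, 2, 0) transpose plus the clip to [0, 1] accomplish. A self-contained shape check with a random stand-in tensor (names and values are purely illustrative):

# Why the preview cell transposes with (1, 2, 0) before imshow:
# PyTorch tensors are (C, H, W); matplotlib expects (H, W, C).
import numpy as np
import torch

chw = torch.rand(3, 64, 64)                 # stand-in for X["A"][0]
hwc = np.transpose(chw.numpy(), (1, 2, 0))  # channel axis moved last
hwc = np.clip(hwc, 0, 1)                    # imshow expects floats in [0, 1]
print(chw.shape, "->", hwc.shape)           # torch.Size([3, 64, 64]) -> (64, 64, 3)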
@@ -109,20 +144,39 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "test_dataset = victorianDataset(test_dir)\n",
-    "test_loader = data.DataLoader(test_dataset,\n",
-    "                              batch_size=batch_size,\n",
-    "                              shuffle=True,\n",
-    "                              # num_workers=4,\n",
-    "                              pin_memory=True)"
+    "test_root = root + 'test/'\n",
+    "test_batch_size = 6\n",
+    "\n",
+    "test_loader = DataLoader(\n",
+    "    victorianDataset(test_root, transforms_=transforms_),\n",
+    "    batch_size=test_batch_size,\n",
+    "    shuffle=True\n",
+    ")"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
    "outputs": [],
-   "source": []
+   "source": [
+    "fig = plt.figure(figsize=(10, 5))\n",
+    "rows = 1 \n",
+    "cols = 2\n",
+    "\n",
+    "for X in test_loader:\n",
+    "    \n",
+    "    ax1 = fig.add_subplot(rows, cols, 1)\n",
+    "    ax1.imshow(np.clip(np.transpose(X[\"A\"][0], (1,2,0)), 0, 1))\n",
+    "    ax1.set_title('gray img')\n",
+    "\n",
+    "    ax2 = fig.add_subplot(rows, cols, 2)\n",
+    "    ax2.imshow(np.clip(np.transpose(X[\"B\"][0], (1,2,0)), 0, 1))\n",
+    "    ax2.set_title('color img') \n",
+    "\n",
+    "    plt.show()\n",
+    "    break"
+   ]
   },
   {
    "cell_type": "code",
@@ -134,12 +188,7 @@
     "    inp = inp.numpy().transpose((1,2,0))\n",
     "    print(inp.shape)\n",
     "    inp = np.clip(inp, 0, 1) \n",
-    "    plt.imshow(inp)\n",
-    "\n",
-    "def gray_imshow(inp):\n",
-    "    inp = inp.numpy()#.transpose((1,2,0))\n",
-    "    print(inp.shape)\n",
-    "    plt.imshow(inp,cmap = plt.get_cmap('gray'))\n"
+    "    plt.imshow(inp)"
    ]
   },
   {
@@ -383,11 +432,11 @@
     "for epoch in range(max_epoch):\n",
     "    loss_D = 0.0\n",
     "    for i, data in enumerate(train_loader):\n",
-    "        gray, color = data\n",
-    "        #print(len(data[0]))\n",
-    "        b_size = len(data[0])\n",
+    "        gray, color = data['A'], data['B']\n",
+    "        # print(len(data['A']))\n",
+    "        b_size = len(data['A'])\n",
     "\n",
-    "        color = torch.from_numpy(np.resize(color.numpy(), (b_size, 3, 64, 64))) ### because we did not normalize above\n",
+    "        color = torch.from_numpy(np.resize(color.numpy(), (b_size, 3, 64, 64))) \n",
     "        # gray >> grays (batch_size * 1 * 64 * 64)\n",
     "        grays = torch.from_numpy(np.resize(gray.numpy(), (b_size, 1, 64, 64)))\n",
     "        \n",
@@ -441,8 +490,7 @@
     "        #print(fake.shape)\n",
     "        fake_img = torchvision.utils.make_grid(fake_img.data)\n",
     "\n",
-    "\n",
-    "        if (epoch + 1) % 50 == 0:\n",
+    "        if (epoch + 1) % 2 == 0:\n",
     "            print('[%d, %5d] real loss: %.4f, fake_loss : %.4f, g_loss : %.4f' % (epoch + 1, i+1, real_loss.item(),fake_loss.item(), g_loss.item()))\n",
     "            imshow(fake_img.cpu())\n",
     "            plt.show()"
@@ -457,19 +505,20 @@
     "Discri.eval()\n",
     "Gener.eval()\n",
     "\n",
-    "fixed_noise = torch.randn(batch_size, 1, 64, 64).uniform_(0,1)\n",
+    "fixed_noise = torch.randn(test_batch_size, 1, 64, 64).uniform_(0,1)\n",
     "\n",
     "for i, data in enumerate(test_loader,0) :\n",
-    "    images, label = data\n",
-    "    \n",
-    "    if len(data[0]) != batch_size:\n",
+    "    images, label = data['A'], data['B']\n",
+    "\n",
+    "    if len(data['A']) != test_batch_size:\n",
     "        continue\n",
-    "    \n",
-    "    grays = torch.from_numpy(np.resize(images.numpy(), (batch_size, 1, 64, 64)))\n",
+    "\n",
+    "    grays = torch.from_numpy(np.resize(images.numpy(), (test_batch_size, 1, 64, 64)))\n",
+    "    print(grays.shape)\n",
     "    \n",
     "    gray = to_variable(torch.cat([grays,fixed_noise],dim = 1))\n",
     "    \n",
-    "    output = Gener(gray)\n",
+    "    # output = Gener(gray)\n",
     "    inputs = torchvision.utils.make_grid(grays)\n",
     "    labels = torchvision.utils.make_grid(label)\n",
     "    out = torchvision.utils.make_grid(output.data)\n",

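In the reworked eval loop, fixed_noise is sized to test_batch_size so it can be concatenated with the 1-channel gray batch along dim=1, producing the 2-channel generator input; a quick shape check with dummy tensors follows. Note that with output = Gener(gray) commented out, the make_grid(output.data) call that follows it would raise a NameError unless output survives from an earlier cell.

# Generator input in the eval loop: gray channel + noise channel, concatenated.
import torch

test_batch_size = 6
grays = torch.rand(test_batch_size, 1, 64, 64)                       # dummy gray batch
fixed_noise = torch.randn(test_batch_size, 1, 64, 64).uniform_(0, 1)

gen_input = torch.cat([grays, fixed_noise], dim=1)
print(gen_input.shape)  # torch.Size([6, 2, 64, 64])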