Skip to content

Commit c4d922c

Browse files
committed
Conditional gan and bug fixes
1 parent 70ed3d3 commit c4d922c

File tree

2 files changed

+264
-52
lines changed

2 files changed

+264
-52
lines changed

SP21/GAN/conditional_gan.ipynb

Lines changed: 222 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,222 @@
1+
{
2+
"metadata": {
3+
"language_info": {
4+
"codemirror_mode": {
5+
"name": "ipython",
6+
"version": 3
7+
},
8+
"file_extension": ".py",
9+
"mimetype": "text/x-python",
10+
"name": "python",
11+
"nbconvert_exporter": "python",
12+
"pygments_lexer": "ipython3",
13+
"version": "3.7.9"
14+
},
15+
"orig_nbformat": 2,
16+
"kernelspec": {
17+
"name": "python379jvsc74a57bd0c3e162dbea2a7bde87c7844ddb2ca1843aba08611dc435aaff358cdc82c48661",
18+
"display_name": "Python 3.7.9 64-bit ('three-d': conda)"
19+
}
20+
},
21+
"nbformat": 4,
22+
"nbformat_minor": 2,
23+
"cells": [
24+
{
25+
"cell_type": "code",
26+
"execution_count": null,
27+
"metadata": {},
28+
"outputs": [],
29+
"source": [
30+
"# Import all the necessary libraries\n",
31+
"import numpy as np\n",
32+
"import torch.nn as nn\n",
33+
"import torch.nn.functional as F\n",
34+
"import torch\n",
35+
"import torchvision\n",
36+
"import matplotlib.pyplot as plt\n",
37+
"from tqdm import notebook"
38+
]
39+
},
40+
{
41+
"cell_type": "code",
42+
"execution_count": null,
43+
"metadata": {},
44+
"outputs": [],
45+
"source": [
46+
"class Generator(nn.Module):\n",
47+
" \"\"\"\n",
48+
" Architecture\n",
49+
" ------------\n",
50+
" Latent Input: latent_shape\n",
51+
" Flattened\n",
52+
" Linear MLP(256, 512, 1024, prod(img_shape))\n",
53+
"\n",
54+
" Leaky Relu activation after every layer except last. (Important!)\n",
55+
" Tanh activation after last layer to normalize\n",
56+
" \"\"\"\n",
57+
" def __init__(self, latent_shape, img_shape):\n",
58+
" super(Generator, self).__init__()\n",
59+
" self.img_shape = img_shape\n",
60+
" self.mlp = nn.Sequential(\n",
61+
" nn.Flatten(),\n",
62+
" nn.Linear(np.prod(latent_shape), 256),\n",
63+
" nn.LeakyReLU(0.2),\n",
64+
" nn.Linear(256, 512),\n",
65+
" nn.LeakyReLU(0.2),\n",
66+
" nn.Linear(512, 1024),\n",
67+
" nn.LeakyReLU(0.2),\n",
68+
" nn.Linear(1024, np.prod(img_shape)),\n",
69+
" nn.Tanh()\n",
70+
" )\n",
71+
" def forward(self, x):\n",
72+
" batch_size = x.shape[0]\n",
73+
" # reshape into a image\n",
74+
" return self.mlp(x).reshape(batch_size, 1, *self.img_shape)"
75+
]
76+
},
77+
{
78+
"cell_type": "code",
79+
"execution_count": null,
80+
"metadata": {},
81+
"outputs": [],
82+
"source": [
83+
"# Define our simple vanilla discriminator\n",
84+
"class Discriminator(nn.Module):\n",
85+
" \"\"\"\n",
86+
" Architecture\n",
87+
" ------------\n",
88+
" Input Image: img_shape\n",
89+
" Flattened\n",
90+
" Linear MLP(128, 512, 256, 1)\n",
91+
" Relu activation after every layer except last.\n",
92+
" Sigmoid activation after last layer to normalize in range 0 to 1\n",
93+
" \"\"\"\n",
94+
" def __init__(self, img_shape):\n",
95+
" super(Discriminator, self).__init__()\n",
96+
"\n",
97+
" self.mlp = nn.Sequential(\n",
98+
" nn.Flatten(),\n",
99+
" nn.Linear(np.prod(img_shape), 1024),\n",
100+
" nn.LeakyReLU(0.2),\n",
101+
" nn.Linear(1024, 512),\n",
102+
" nn.LeakyReLU(0.2),\n",
103+
" nn.Linear(512, 256),\n",
104+
" nn.LeakyReLU(0.2),\n",
105+
" nn.Linear(256, 1),\n",
106+
" nn.Sigmoid()\n",
107+
" )\n",
108+
" def forward(self, x):\n",
109+
" return self.mlp(x)"
110+
]
111+
},
112+
{
113+
"cell_type": "code",
114+
"execution_count": null,
115+
"metadata": {},
116+
"outputs": [],
117+
"source": [
118+
"# load our data\n",
119+
"latent_shape = (28, 28)\n",
120+
"img_shape = (28, 28)\n",
121+
"batch_size = 64\n",
122+
"\n",
123+
"transform = transforms.Compose([\n",
124+
" transforms.ToTensor(),\n",
125+
" transforms.Normalize(mean=(0.5), std=(0.5))])\n",
126+
"train_dataset = torchvision.datasets.MNIST(root=\"./data\", train = True, download=True, transform=transform)der(train_dataset, batch_size=batch_size, shuffle=True)"
127+
]
128+
},
129+
{
130+
"cell_type": "code",
131+
"execution_count": null,
132+
"metadata": {},
133+
"outputs": [],
134+
"source": [
135+
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # for gpu usage if possible\n",
136+
"\n",
137+
"generator = Generator(latent_shape, img_shape)\n",
138+
"discriminator = Discriminator(img_shape)\n",
139+
"\n",
140+
"gen_optim = torch.optim.Adam(generator.parameters(), lr=2e-4)\n",
141+
"disc_optim = torch.optim.Adam(discriminator.parameters(), lr=2e-4)\n",
142+
"\n",
143+
"# use gpu if possible\n",
144+
"generator = generator.to(device)\n",
145+
"discriminator = discriminator.to(device)"
146+
]
147+
},
148+
{
149+
"cell_type": "code",
150+
"execution_count": null,
151+
"metadata": {},
152+
"outputs": [],
153+
"source": [
154+
"def train(generator, discriminator, generator_optim: torch.optim, discriminator_optim: torch.optim, epochs=10):\n",
155+
" adversarial_loss = torch.nn.BCELoss()\n",
156+
" \n",
157+
" for epoch in range(1, epochs+1):\n",
158+
" print(\"Epoch {}\".format(epoch))\n",
159+
" avg_g_loss = 0\n",
160+
" avg_d_loss = 0\n",
161+
" pbar = notebook.tqdm(train_dataloader, total=len(train_dataloader))\n",
162+
" i = 0\n",
163+
" for data in pbar:\n",
164+
" i += 1\n",
165+
" real_images = data[0].to(device)\n",
166+
" ### Train Generator ###\n",
167+
" generator_optim.zero_grad()\n",
168+
" \n",
169+
" latent_input = torch.randn((batch_size, 1, *latent_shape)).to(device)\n",
170+
" fake_images = generator(latent_input)\n",
171+
"\n",
172+
" fake_res = discriminator(fake_images)\n",
173+
" \n",
174+
" generator_loss = adversarial_loss(fake_res, torch.ones_like(fake_res))\n",
175+
" generator_loss.backward()\n",
176+
" generator_optim.step()\n",
177+
" \n",
178+
" ### Train Discriminator ###\n",
179+
" discriminator_optim.zero_grad()\n",
180+
" \n",
181+
" real_res = discriminator(real_images)\n",
182+
"\n",
183+
" fake_res = discriminator(fake_images.detach())\n",
184+
"\n",
185+
" discriminator_real_loss = adversarial_loss(real_res, torch.ones_like(real_res))\n",
186+
" discriminator_fake_loss = adversarial_loss(fake_res, torch.zeros_like(real_res))\n",
187+
" discriminator_loss = (discriminator_real_loss + discriminator_fake_loss) / 2\n",
188+
" discriminator_loss.backward()\n",
189+
" discriminator_optim.step()\n",
190+
" \n",
191+
"\n",
192+
" avg_g_loss += generator_loss.item()\n",
193+
" avg_d_loss += discriminator_loss.item()\n",
194+
" pbar.set_postfix({\"G_loss\": generator_loss.item(), \"D_loss\": discriminator_loss.item()})"
195+
]
196+
},
197+
{
198+
"cell_type": "code",
199+
"execution_count": null,
200+
"metadata": {},
201+
"outputs": [],
202+
"source": [
203+
"# train our generator and discriminator\n",
204+
"# Note: don't always expect loss to go down simultaneously for both models. They are competing against each other! So sometimes one model \n",
205+
"# may perform better than the other\n",
206+
"train(generator=generator, discriminator=discriminator, generator_optim=gen_optim, discriminator_optim=disc_optim)"
207+
]
208+
},
209+
{
210+
"cell_type": "code",
211+
"execution_count": null,
212+
"metadata": {},
213+
"outputs": [],
214+
"source": [
215+
"# test it out!\n",
216+
"latent_input = torch.randn((batch_size, 1, *latent_shape))\n",
217+
"test = generator(latent_input.to(device))\n",
218+
"plt.imshow(test[0].reshape(28, 28).cpu().detach().numpy())"
219+
]
220+
}
221+
]
222+
}

0 commit comments

Comments
 (0)