
Commit 9847f36

initial code
1 parent 26dd6b3 commit 9847f36

File tree

1 file changed: +235 −0


SP21/GAN/vanilla_gan.ipynb

Lines changed: 235 additions & 0 deletions
@@ -0,0 +1,235 @@
{
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  },
  "orig_nbformat": 2,
  "kernelspec": {
   "name": "python379jvsc74a57bd0c3e162dbea2a7bde87c7844ddb2ca1843aba08611dc435aaff358cdc82c48661",
   "display_name": "Python 3.7.9 64-bit ('three-d': conda)"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2,
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import all the necessary libraries\n",
    "import numpy as np\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch\n",
    "import torchvision\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm import notebook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define our simple vanilla generator\n",
    "class Generator(nn.Module):\n",
    "    \"\"\"\n",
    "    Architecture\n",
    "    ------------\n",
    "    Latent Input: latent_shape\n",
    "    Flattened\n",
    "    Linear MLP(256, 512, 1024, prod(img_shape))\n",
    "\n",
    "    LeakyReLU activation after every layer except the last. (Important!)\n",
    "    Tanh activation after the last layer to normalize outputs to [-1, 1]\n",
    "    \"\"\"\n",
    "    def __init__(self, latent_shape, img_shape):\n",
    "        super(Generator, self).__init__()\n",
    "        self.img_shape = img_shape\n",
    "        self.mlp = nn.Sequential(\n",
    "            nn.Flatten(),\n",
    "            nn.Linear(np.prod(latent_shape), 256),\n",
    "            nn.LeakyReLU(0.2),\n",
    "            nn.Linear(256, 512),\n",
    "            nn.LeakyReLU(0.2),\n",
    "            nn.Linear(512, 1024),\n",
    "            nn.LeakyReLU(0.2),\n",
    "            nn.Linear(1024, np.prod(img_shape)),\n",
    "            nn.Tanh()\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        batch_size = x.shape[0]\n",
    "        # reshape the flat MLP output back into an image\n",
    "        return self.mlp(x).reshape(batch_size, 1, *self.img_shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define our simple vanilla discriminator\n",
    "class Discriminator(nn.Module):\n",
    "    \"\"\"\n",
    "    Architecture\n",
    "    ------------\n",
    "    Input Image: img_shape\n",
    "    Flattened\n",
    "    Linear MLP(128, 512, 256, 1)\n",
    "\n",
    "    ReLU activation after every layer except the last.\n",
    "    Sigmoid activation after the last layer to normalize into the range 0 to 1\n",
    "    \"\"\"\n",
    "    def __init__(self, img_shape):\n",
    "        super(Discriminator, self).__init__()\n",
    "\n",
    "        self.mlp = nn.Sequential(\n",
    "            nn.Flatten(),\n",
    "            nn.Linear(np.prod(img_shape), 128),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(128, 512),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(512, 256),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(256, 1),\n",
    "            nn.Sigmoid()\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.mlp(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load our data\n",
    "latent_shape = (28, 28)\n",
    "img_shape = (28, 28)\n",
    "batch_size = 64\n",
    "\n",
    "transform = torchvision.transforms.Compose(\n",
    "    [\n",
    "        torchvision.transforms.ToTensor()  # converts the PIL Image format to a pytorch tensor\n",
    "    ]\n",
    ")\n",
    "train_dataset = torchvision.datasets.MNIST(root=\"./data\", train=True, download=True, transform=transform)\n",
    "train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # for gpu usage if possible\n",
    "\n",
    "generator = Generator(latent_shape, img_shape)\n",
    "discriminator = Discriminator(img_shape)\n",
    "\n",
    "gen_optim = torch.optim.Adam(generator.parameters(), lr=2e-4)\n",
    "disc_optim = torch.optim.Adam(discriminator.parameters(), lr=2e-4)\n",
    "\n",
    "# use gpu if possible\n",
    "generator = generator.to(device)\n",
    "discriminator = discriminator.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(generator, discriminator, generator_optim: torch.optim.Optimizer, discriminator_optim: torch.optim.Optimizer, epochs=10):\n",
    "    adversarial_loss = torch.nn.BCELoss()\n",
    "\n",
    "    for epoch in range(1, epochs + 1):\n",
    "        print(\"Epoch {}\".format(epoch))\n",
    "        avg_g_loss = 0\n",
    "        avg_d_loss = 0\n",
    "        pbar = notebook.tqdm(train_dataloader, total=len(train_dataloader))\n",
    "        i = 0\n",
    "        for data in pbar:\n",
    "            i += 1\n",
    "            real_images = data[0].to(device)\n",
    "\n",
    "            ### Train Generator ###\n",
    "            generator_optim.zero_grad()\n",
    "\n",
    "            # sample as many latent vectors as there are real images, so the\n",
    "            # last (smaller) batch of an epoch does not cause a size mismatch\n",
    "            latent_input = torch.randn((real_images.shape[0], 1, *latent_shape)).to(device)\n",
    "            fake_images = generator(latent_input)\n",
    "\n",
    "            fake_res = discriminator(fake_images)\n",
    "\n",
    "            # the generator wants the discriminator to label its fakes as real (1)\n",
    "            generator_loss = adversarial_loss(fake_res, torch.ones_like(fake_res))\n",
    "            generator_loss.backward()\n",
    "            generator_optim.step()\n",
    "\n",
    "            ### Train Discriminator ###\n",
    "            discriminator_optim.zero_grad()\n",
    "\n",
    "            real_res = discriminator(real_images)\n",
    "            # detach so discriminator gradients don't flow into the generator\n",
    "            fake_res = discriminator(fake_images.detach())\n",
    "\n",
    "            discriminator_real_loss = adversarial_loss(real_res, torch.ones_like(real_res))\n",
    "            discriminator_fake_loss = adversarial_loss(fake_res, torch.zeros_like(fake_res))\n",
    "            discriminator_loss = (discriminator_real_loss + discriminator_fake_loss) / 2\n",
    "            discriminator_loss.backward()\n",
    "            discriminator_optim.step()\n",
    "\n",
    "            avg_g_loss += generator_loss.item()\n",
    "            avg_d_loss += discriminator_loss.item()\n",
    "            pbar.set_postfix({\"G_loss\": generator_loss.item(), \"D_loss\": discriminator_loss.item()})\n",
    "        print(\"Avg G_loss: {:.4f}, Avg D_loss: {:.4f}\".format(avg_g_loss / i, avg_d_loss / i))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# train our generator and discriminator\n",
    "# Note: don't always expect the loss to go down simultaneously for both models.\n",
    "# They are competing against each other, so sometimes one model may perform better than the other.\n",
    "train(generator=generator, discriminator=discriminator, generator_optim=gen_optim, discriminator_optim=disc_optim)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# test it out!\n",
    "latent_input = torch.randn((batch_size, 1, *latent_shape))\n",
    "test = generator(latent_input.to(device))\n",
    "plt.imshow(test[0].reshape(28, 28).cpu().detach().numpy())"
   ]
  }
 ]
}
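
For reference, a minimal smoke test one could run after executing the notebook cells above (hypothetical, not part of this commit). It checks that Generator maps a batch of latent samples to image-shaped tensors and that Discriminator returns one sigmoid score per image.

# Hypothetical shape check, assuming the Generator and Discriminator cells above have been run
import torch

latent_shape, img_shape = (28, 28), (28, 28)
g = Generator(latent_shape, img_shape)
d = Discriminator(img_shape)

z = torch.randn((4, 1, *latent_shape))   # batch of 4 latent samples
fake = g(z)                              # -> (4, 1, 28, 28)
assert fake.shape == (4, 1, *img_shape)

scores = d(fake)                         # -> (4, 1), sigmoid outputs in (0, 1)
assert scores.shape == (4, 1)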
