Skip to content

Commit d0d6438

Browse files
committed
changes
1 parent c4d922c commit d0d6438

File tree

4 files changed

+570
-112
lines changed

4 files changed

+570
-112
lines changed
Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"# Import all the necessary libraries\n",
10+
"import numpy as np\n",
11+
"import torch.nn as nn\n",
12+
"import torch.nn.functional as F\n",
13+
"import torch\n",
14+
"import torchvision\n",
15+
"import matplotlib.pyplot as plt\n",
16+
"from tqdm import notebook"
17+
]
18+
},
19+
{
20+
"cell_type": "code",
21+
"execution_count": null,
22+
"metadata": {},
23+
"outputs": [],
24+
"source": [
25+
"class Generator(nn.Module):\n",
26+
" def __init__(self, latent_shape, img_shape):\n",
27+
" super(Generator, self).__init__()\n",
28+
" self.img_shape = img_shape\n",
29+
" self.flatten = nn.Flatten()\n",
30+
" self.mlp = nn.Sequential(\n",
31+
" nn.Linear(np.prod(latent_shape) + 10, 256),\n",
32+
" nn.LeakyReLU(0.2),\n",
33+
" nn.Linear(256, 512),\n",
34+
" nn.LeakyReLU(0.2),\n",
35+
" nn.Linear(512, 1024),\n",
36+
" nn.LeakyReLU(0.2),\n",
37+
" nn.Linear(1024, np.prod(img_shape)),\n",
38+
" nn.Tanh()\n",
39+
" )\n",
40+
" def forward(self, x, label):\n",
41+
" batch_size = x.shape[0]\n",
42+
" # generator now uses the latent input noise x and a one hot encoded label for conditioning to generate a fake digit\n",
43+
" x = self.flatten(x)\n",
44+
" x = torch.cat([x, label], dim=1)\n",
45+
" # reshape into a image\n",
46+
" return self.mlp(x).reshape(batch_size, 1, *self.img_shape)"
47+
]
48+
},
49+
{
50+
"cell_type": "code",
51+
"execution_count": null,
52+
"metadata": {},
53+
"outputs": [],
54+
"source": [
55+
"class Discriminator(nn.Module):\n",
56+
" def __init__(self, img_shape):\n",
57+
" super(Discriminator, self).__init__()\n",
58+
"\n",
59+
" self.mlp = nn.Sequential(\n",
60+
" nn.Flatten(),\n",
61+
" nn.Linear(np.prod(img_shape), 1024),\n",
62+
" nn.LeakyReLU(0.2),\n",
63+
" nn.Linear(1024, 512),\n",
64+
" nn.LeakyReLU(0.2),\n",
65+
" nn.Linear(512, 256),\n",
66+
" nn.LeakyReLU(0.2),\n",
67+
" nn.Linear(256, 1),\n",
68+
" nn.Sigmoid()\n",
69+
" )\n",
70+
" def forward(self, x):\n",
71+
" return self.mlp(x)"
72+
]
73+
},
74+
{
75+
"cell_type": "code",
76+
"execution_count": null,
77+
"metadata": {},
78+
"outputs": [],
79+
"source": [
80+
"# load our data\n",
81+
"latent_shape = (28, 28)\n",
82+
"img_shape = (28, 28)\n",
83+
"batch_size = 64\n",
84+
"\n",
85+
"transform = transforms.Compose([\n",
86+
" transforms.ToTensor(),\n",
87+
" transforms.Normalize(mean=(0.5), std=(0.5))])\n",
88+
"train_dataset = torchvision.datasets.MNIST(root=\"./data\", train = True, download=True, transform=transform)der(train_dataset, batch_size=batch_size, shuffle=True)"
89+
]
90+
},
91+
{
92+
"cell_type": "code",
93+
"execution_count": null,
94+
"metadata": {},
95+
"outputs": [],
96+
"source": [
97+
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # for gpu usage if possible\n",
98+
"\n",
99+
"generator = Generator(latent_shape, img_shape)\n",
100+
"discriminator = Discriminator(img_shape)\n",
101+
"\n",
102+
"gen_optim = torch.optim.Adam(generator.parameters(), lr=2e-4)\n",
103+
"disc_optim = torch.optim.Adam(discriminator.parameters(), lr=2e-4)\n",
104+
"\n",
105+
"# use gpu if possible\n",
106+
"generator = generator.to(device)\n",
107+
"discriminator = discriminator.to(device)"
108+
]
109+
},
110+
{
111+
"cell_type": "code",
112+
"execution_count": null,
113+
"metadata": {},
114+
"outputs": [],
115+
"source": [
116+
"def train(generator, discriminator, generator_optim: torch.optim, discriminator_optim: torch.optim, epochs=100):\n",
117+
" adversarial_loss = torch.nn.BCELoss()\n",
118+
" \n",
119+
" for epoch in range(1, epochs+1):\n",
120+
" print(\"Epoch {}\".format(epoch))\n",
121+
" avg_g_loss = 0\n",
122+
" avg_d_loss = 0\n",
123+
" pbar = notebook.tqdm(train_dataloader, total=len(train_dataloader))\n",
124+
" i = 0\n",
125+
" for data in pbar:\n",
126+
" i += 1\n",
127+
" real_images = data[0].to(device)\n",
128+
" labels = data[1].to(device)\n",
129+
"\n",
130+
" one_hot_labels = torch.zeros((len(labels), 10)).to(device)\n",
131+
" for j in range(len(labels)):\n",
132+
" one_hot_labels[j][labels[j]] = 1\n",
133+
"\n",
134+
" ### Train Generator ###\n",
135+
" \n",
136+
" generator_optim.zero_grad()\n",
137+
" \n",
138+
" latent_input = torch.randn((len(real_images), 1, *latent_shape)).to(device)\n",
139+
"\n",
140+
" fake_images = generator(latent_input, one_hot_labels)\n",
141+
"\n",
142+
" fake_res = discriminator(fake_images)\n",
143+
" \n",
144+
" generator_loss = adversarial_loss(fake_res, torch.ones_like(fake_res))\n",
145+
" generator_loss.backward()\n",
146+
" generator_optim.step()\n",
147+
" \n",
148+
" ### Train Discriminator ###\n",
149+
" discriminator_optim.zero_grad()\n",
150+
" \n",
151+
" real_res = discriminator(real_images)\n",
152+
"\n",
153+
" fake_res = discriminator(fake_images.detach())\n",
154+
"\n",
155+
" discriminator_real_loss = adversarial_loss(real_res, torch.ones_like(real_res))\n",
156+
" discriminator_fake_loss = adversarial_loss(fake_res, torch.zeros_like(fake_res))\n",
157+
" discriminator_loss = (discriminator_real_loss + discriminator_fake_loss) / 2\n",
158+
" discriminator_loss.backward()\n",
159+
" discriminator_optim.step()\n",
160+
" \n",
161+
"\n",
162+
" avg_g_loss += generator_loss.item()\n",
163+
" avg_d_loss += discriminator_loss.item()\n",
164+
" pbar.set_postfix({\"G_loss\": generator_loss.item(), \"D_loss\": discriminator_loss.item()})\n",
165+
" print(\"Avg G_loss {} - Avg D_loss {}\".format(avg_g_loss / i, avg_d_loss / i))"
166+
]
167+
},
168+
{
169+
"cell_type": "code",
170+
"execution_count": null,
171+
"metadata": {},
172+
"outputs": [],
173+
"source": [
174+
"# train our generator and discriminator\n",
175+
"# Note: don't always expect loss to go down simultaneously for both models. They are competing against each other! So sometimes one model \n",
176+
"# may perform better than the other\n",
177+
"train(generator=generator, discriminator=discriminator, generator_optim=gen_optim, discriminator_optim=disc_optim)"
178+
]
179+
},
180+
{
181+
"cell_type": "code",
182+
"execution_count": null,
183+
"metadata": {},
184+
"outputs": [],
185+
"source": [
186+
"# test it out!\n",
187+
"latent_input = torch.randn((batch_size, 1, *latent_shape))\n",
188+
"\n",
189+
"# generate one hot encoded labels\n",
190+
"labels = torch.zeros((batch_size))\n",
191+
"one_hot_labels = torch.zeros((batch_size, 10))\n",
192+
"one_hot_labels[torch.arange(batch_size), labels] = 1\n",
193+
"\n",
194+
"test = generator(latent_input.to(device), one_hot_labels)"
195+
]
196+
},
197+
{
198+
"cell_type": "code",
199+
"execution_count": null,
200+
"metadata": {},
201+
"outputs": [],
202+
"source": [
203+
"k = 0\n",
204+
"plt.title(\"Generating a fake {} digit\".format(one_hot_labels[k]))\n",
205+
"plt.imshow(test[k].reshape(28, 28).cpu().detach().numpy())"
206+
]
207+
}
208+
],
209+
"metadata": {
210+
"kernelspec": {
211+
"display_name": "Python 3",
212+
"language": "python",
213+
"name": "python3"
214+
},
215+
"language_info": {
216+
"codemirror_mode": {
217+
"name": "ipython",
218+
"version": 3
219+
},
220+
"file_extension": ".py",
221+
"mimetype": "text/x-python",
222+
"name": "python",
223+
"nbconvert_exporter": "python",
224+
"pygments_lexer": "ipython3",
225+
"version": "3.8.5"
226+
}
227+
},
228+
"nbformat": 4,
229+
"nbformat_minor": 2
230+
}

0 commit comments

Comments
 (0)