{
"metadata": {
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.9"
},
"orig_nbformat": 2,
"kernelspec": {
"name": "python379jvsc74a57bd0c3e162dbea2a7bde87c7844ddb2ca1843aba08611dc435aaff358cdc82c48661",
"display_name": "Python 3.7.9 64-bit ('three-d': conda)"
}
},
"nbformat": 4,
"nbformat_minor": 2,
"cells": [
23+ "cells" : [
24+ {
25+ "cell_type" : " code" ,
26+ "execution_count" : null ,
27+ "metadata" : {},
28+ "outputs" : [],
29+ "source" : [
30+ " # Import all the necessary libraries\n " ,
31+ " import numpy as np\n " ,
32+ " import torch.nn as nn\n " ,
33+ " import torch.nn.functional as F\n " ,
34+ " import torch\n " ,
35+ " import torchvision\n " ,
36+ " import matplotlib.pyplot as plt\n " ,
37+ " from tqdm import notebook"
38+ ]
39+ },
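{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional added sketch: fix the random seeds so runs are easier to repeat.\n",
"# This is not part of the original training recipe, and GAN results will still\n",
"# vary somewhat across hardware and library versions.\n",
"seed = 42\n",
"np.random.seed(seed)\n",
"torch.manual_seed(seed)"
]
},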
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Generator(nn.Module):\n",
"    \"\"\"\n",
"    Architecture\n",
"    ------------\n",
"    Latent Input: latent_shape\n",
"    Flattened\n",
"    Linear MLP(256, 512, 1024, prod(img_shape))\n",
"\n",
"    LeakyReLU activation after every layer except the last. (Important!)\n",
"    Tanh activation after the last layer to normalize\n",
"    \"\"\"\n",
"    def __init__(self, latent_shape, img_shape):\n",
"        super(Generator, self).__init__()\n",
"        self.img_shape = img_shape\n",
"        self.mlp = nn.Sequential(\n",
"            nn.Flatten(),\n",
"            nn.Linear(np.prod(latent_shape), 256),\n",
"            nn.LeakyReLU(0.2),\n",
"            nn.Linear(256, 512),\n",
"            nn.LeakyReLU(0.2),\n",
"            nn.Linear(512, 1024),\n",
"            nn.LeakyReLU(0.2),\n",
"            nn.Linear(1024, np.prod(img_shape)),\n",
"            nn.Tanh()\n",
"        )\n",
"\n",
"    def forward(self, x):\n",
"        batch_size = x.shape[0]\n",
"        # reshape the flat output back into an image\n",
"        return self.mlp(x).reshape(batch_size, 1, *self.img_shape)"
]
},
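{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Quick sanity check (added sketch): push a random latent batch through an\n",
"# untrained Generator and confirm the output has the expected image shape.\n",
"# The (28, 28) shapes here simply mirror the MNIST settings used later.\n",
"_g = Generator(latent_shape=(28, 28), img_shape=(28, 28))\n",
"_z = torch.randn(4, 1, 28, 28)\n",
"print(_g(_z).shape)  # expected: torch.Size([4, 1, 28, 28])"
]
},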
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Define our simple vanilla discriminator\n",
"class Discriminator(nn.Module):\n",
"    \"\"\"\n",
"    Architecture\n",
"    ------------\n",
"    Input Image: img_shape\n",
"    Flattened\n",
"    Linear MLP(1024, 512, 256, 1)\n",
"\n",
"    LeakyReLU activation after every layer except the last.\n",
"    Sigmoid activation after the last layer to normalize the output to the range 0 to 1\n",
"    \"\"\"\n",
"    def __init__(self, img_shape):\n",
"        super(Discriminator, self).__init__()\n",
"\n",
"        self.mlp = nn.Sequential(\n",
"            nn.Flatten(),\n",
"            nn.Linear(np.prod(img_shape), 1024),\n",
"            nn.LeakyReLU(0.2),\n",
"            nn.Linear(1024, 512),\n",
"            nn.LeakyReLU(0.2),\n",
"            nn.Linear(512, 256),\n",
"            nn.LeakyReLU(0.2),\n",
"            nn.Linear(256, 1),\n",
"            nn.Sigmoid()\n",
"        )\n",
"\n",
"    def forward(self, x):\n",
"        return self.mlp(x)"
]
},
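{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Quick sanity check (added sketch): the Discriminator should map a batch of\n",
"# images to one probability per image, i.e. an output of shape (batch, 1).\n",
"_d = Discriminator(img_shape=(28, 28))\n",
"_imgs = torch.randn(4, 1, 28, 28)\n",
"print(_d(_imgs).shape)  # expected: torch.Size([4, 1])"
]
},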
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# load our data\n",
"latent_shape = (28, 28)\n",
"img_shape = (28, 28)\n",
"batch_size = 64\n",
"\n",
"transform = transforms.Compose([\n",
"    transforms.ToTensor(),\n",
"    transforms.Normalize(mean=(0.5,), std=(0.5,))])\n",
"train_dataset = torchvision.datasets.MNIST(root=\"./data\", train=True, download=True, transform=transform)\n",
"train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)"
]
},
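{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Added sketch: peek at one batch of real MNIST digits so we know what the\n",
"# generator is being asked to imitate. The images were normalized to [-1, 1],\n",
"# so they are shifted back to [0, 1] before plotting.\n",
"real_batch, _ = next(iter(train_dataloader))\n",
"grid = torchvision.utils.make_grid(real_batch[:16], nrow=4)\n",
"grid = grid * 0.5 + 0.5  # undo Normalize(mean=0.5, std=0.5)\n",
"plt.figure(figsize=(4, 4))\n",
"plt.axis(\"off\")\n",
"plt.imshow(grid.permute(1, 2, 0).numpy())\n",
"plt.show()"
]
},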
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # for gpu usage if possible\n",
"\n",
"generator = Generator(latent_shape, img_shape)\n",
"discriminator = Discriminator(img_shape)\n",
"\n",
"gen_optim = torch.optim.Adam(generator.parameters(), lr=2e-4)\n",
"disc_optim = torch.optim.Adam(discriminator.parameters(), lr=2e-4)\n",
"\n",
"# use gpu if possible\n",
"generator = generator.to(device)\n",
"discriminator = discriminator.to(device)"
]
},
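{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Added sketch: report how many trainable parameters each network has.\n",
"# count_parameters is a small helper defined in this cell (not a library\n",
"# function); it is purely informational and does not affect training.\n",
"def count_parameters(model):\n",
"    return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
"\n",
"print(\"Generator parameters:    \", count_parameters(generator))\n",
"print(\"Discriminator parameters:\", count_parameters(discriminator))"
]
},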
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def train(generator, discriminator, generator_optim: torch.optim.Optimizer, discriminator_optim: torch.optim.Optimizer, epochs=10):\n",
"    adversarial_loss = torch.nn.BCELoss()\n",
"\n",
"    for epoch in range(1, epochs + 1):\n",
"        print(\"Epoch {}\".format(epoch))\n",
"        avg_g_loss = 0\n",
"        avg_d_loss = 0\n",
"        pbar = notebook.tqdm(train_dataloader, total=len(train_dataloader))\n",
"        i = 0\n",
"        for data in pbar:\n",
"            i += 1\n",
"            real_images = data[0].to(device)\n",
"\n",
"            ### Train Generator ###\n",
"            generator_optim.zero_grad()\n",
"\n",
"            latent_input = torch.randn((batch_size, 1, *latent_shape)).to(device)\n",
"            fake_images = generator(latent_input)\n",
"\n",
"            fake_res = discriminator(fake_images)\n",
"\n",
"            # the generator wants the discriminator to label its fakes as real (1)\n",
"            generator_loss = adversarial_loss(fake_res, torch.ones_like(fake_res))\n",
"            generator_loss.backward()\n",
"            generator_optim.step()\n",
"\n",
"            ### Train Discriminator ###\n",
"            discriminator_optim.zero_grad()\n",
"\n",
"            real_res = discriminator(real_images)\n",
"            # detach so the discriminator's loss does not backpropagate into the generator\n",
"            fake_res = discriminator(fake_images.detach())\n",
"\n",
"            discriminator_real_loss = adversarial_loss(real_res, torch.ones_like(real_res))\n",
"            discriminator_fake_loss = adversarial_loss(fake_res, torch.zeros_like(fake_res))\n",
"            discriminator_loss = (discriminator_real_loss + discriminator_fake_loss) / 2\n",
"            discriminator_loss.backward()\n",
"            discriminator_optim.step()\n",
"\n",
"            avg_g_loss += generator_loss.item()\n",
"            avg_d_loss += discriminator_loss.item()\n",
"            pbar.set_postfix({\"G_loss\": generator_loss.item(), \"D_loss\": discriminator_loss.item()})\n",
"\n",
"        print(\"Avg G_loss: {:.4f}, Avg D_loss: {:.4f}\".format(avg_g_loss / i, avg_d_loss / i))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# train our generator and discriminator\n",
"# Note: don't always expect the loss to go down simultaneously for both models. They are\n",
"# competing against each other, so at times one model may perform better than the other.\n",
"train(generator=generator, discriminator=discriminator, generator_optim=gen_optim, discriminator_optim=disc_optim)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# test it out!\n",
"latent_input = torch.randn((batch_size, 1, *latent_shape))\n",
"test = generator(latent_input.to(device))\n",
"plt.imshow(test[0].reshape(28, 28).cpu().detach().numpy(), cmap=\"gray\")"
]
},
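{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Added sketch: look at a whole grid of generated digits instead of a single one,\n",
"# and optionally save the trained generator weights. The file name gan_generator.pt\n",
"# is just an example; nothing else in the notebook depends on it.\n",
"with torch.no_grad():\n",
"    samples = generator(torch.randn((16, 1, *latent_shape)).to(device)).cpu()\n",
"grid = torchvision.utils.make_grid(samples, nrow=4)\n",
"grid = grid * 0.5 + 0.5  # map Tanh output from [-1, 1] back to [0, 1]\n",
"plt.figure(figsize=(4, 4))\n",
"plt.axis(\"off\")\n",
"plt.imshow(grid.permute(1, 2, 0).numpy())\n",
"plt.show()\n",
"\n",
"torch.save(generator.state_dict(), \"gan_generator.pt\")"
]
}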
]
}