{
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  },
  "orig_nbformat": 2,
  "kernelspec": {
   "name": "python379jvsc74a57bd0c3e162dbea2a7bde87c7844ddb2ca1843aba08611dc435aaff358cdc82c48661",
   "display_name": "Python 3.7.9 64-bit ('three-d': conda)"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2,
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import all the necessary libraries\n",
    "import numpy as np\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch\n",
    "import torchvision\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm import notebook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define our simple vanilla generator\n",
    "class Generator(nn.Module):\n",
    "    \"\"\"\n",
    "    Architecture\n",
    "    ------------\n",
    "    Latent input: latent_shape\n",
    "    Flattened\n",
    "    Linear MLP(256, 512, 1024, prod(img_shape))\n",
    "\n",
    "    LeakyReLU activation after every layer except the last. (Important!)\n",
    "    Tanh activation after the last layer to normalize outputs to [-1, 1]\n",
    "    \"\"\"\n",
    "    def __init__(self, latent_shape, img_shape):\n",
    "        super(Generator, self).__init__()\n",
    "        self.img_shape = img_shape\n",
    "        self.mlp = nn.Sequential(\n",
    "            nn.Flatten(),\n",
    "            nn.Linear(np.prod(latent_shape), 256),\n",
    "            nn.LeakyReLU(0.2),\n",
    "            nn.Linear(256, 512),\n",
    "            nn.LeakyReLU(0.2),\n",
    "            nn.Linear(512, 1024),\n",
    "            nn.LeakyReLU(0.2),\n",
    "            nn.Linear(1024, np.prod(img_shape)),\n",
    "            nn.Tanh()\n",
    "        )\n",
    "    def forward(self, x):\n",
    "        batch_size = x.shape[0]\n",
    "        # reshape the flat output into an image\n",
    "        return self.mlp(x).reshape(batch_size, 1, *self.img_shape)"
   ]
  },
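  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity-check sketch (added, not part of the original notebook): feed a small\n",
    "# batch of noise through the generator and confirm the output shape and range.\n",
    "_g = Generator(latent_shape=(28, 28), img_shape=(28, 28))\n",
    "_z = torch.randn((4, 1, 28, 28))\n",
    "_out = _g(_z)\n",
    "print(_out.shape)  # expected: torch.Size([4, 1, 28, 28])\n",
    "print(_out.min().item(), _out.max().item())  # Tanh keeps values in [-1, 1]"
   ]
  },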
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define our simple vanilla discriminator\n",
    "class Discriminator(nn.Module):\n",
    "    \"\"\"\n",
    "    Architecture\n",
    "    ------------\n",
    "    Input image: img_shape\n",
    "    Flattened\n",
    "    Linear MLP(128, 512, 256, 1)\n",
    "    ReLU activation after every layer except the last.\n",
    "    Sigmoid activation after the last layer to squash the output into (0, 1)\n",
    "    \"\"\"\n",
    "    def __init__(self, img_shape):\n",
    "        super(Discriminator, self).__init__()\n",
    "\n",
    "        self.mlp = nn.Sequential(\n",
    "            nn.Flatten(),\n",
    "            nn.Linear(np.prod(img_shape), 128),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(128, 512),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(512, 256),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(256, 1),\n",
    "            nn.Sigmoid()\n",
    "        )\n",
    "    def forward(self, x):\n",
    "        return self.mlp(x)"
   ]
  },
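  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity-check sketch (added): the discriminator should emit one probability per\n",
    "# image, kept inside (0, 1) by the final Sigmoid.\n",
    "_d = Discriminator(img_shape=(28, 28))\n",
    "_p = _d(torch.randn((4, 1, 28, 28)))\n",
    "print(_p.shape)  # expected: torch.Size([4, 1])\n",
    "print(_p.min().item(), _p.max().item())"
   ]
  },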
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load our data\n",
    "latent_shape = (28, 28)\n",
    "img_shape = (28, 28)\n",
    "batch_size = 64\n",
    "\n",
    "transform = torchvision.transforms.Compose(\n",
    "    [\n",
    "        torchvision.transforms.ToTensor(),  # converts the PIL Image to a PyTorch tensor in [0, 1]\n",
    "        torchvision.transforms.Normalize((0.5,), (0.5,))  # rescale to [-1, 1] to match the generator's Tanh output\n",
    "    ]\n",
    ")\n",
    "train_dataset = torchvision.datasets.MNIST(root=\"./data\", train=True, download=True, transform=transform)\n",
    "train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)"
   ]
  },
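  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sketch (added): peek at one real batch to confirm shapes and scaling\n",
    "# before training.\n",
    "real_batch, _labels = next(iter(train_dataloader))\n",
    "print(real_batch.shape)  # (batch_size, 1, 28, 28)\n",
    "grid = torchvision.utils.make_grid(real_batch[:16], nrow=4, normalize=True)\n",
    "plt.figure(figsize=(6, 6))\n",
    "plt.imshow(grid.permute(1, 2, 0).numpy())\n",
    "plt.axis(\"off\")\n",
    "plt.show()"
   ]
  },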
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # use the GPU if available\n",
    "\n",
    "generator = Generator(latent_shape, img_shape)\n",
    "discriminator = Discriminator(img_shape)\n",
    "\n",
    "gen_optim = torch.optim.Adam(generator.parameters(), lr=2e-4)\n",
    "disc_optim = torch.optim.Adam(discriminator.parameters(), lr=2e-4)\n",
    "\n",
    "# move both models to the chosen device\n",
    "generator = generator.to(device)\n",
    "discriminator = discriminator.to(device)"
   ]
  },
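  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rough size check (added sketch): count the trainable parameters in each network.\n",
    "def count_params(model):\n",
    "    return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "\n",
    "print(\"Generator parameters:\", count_params(generator))\n",
    "print(\"Discriminator parameters:\", count_params(discriminator))"
   ]
  },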
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(generator, discriminator, generator_optim: torch.optim.Optimizer, discriminator_optim: torch.optim.Optimizer, epochs=10):\n",
    "    adversarial_loss = torch.nn.BCELoss()\n",
    "\n",
    "    for epoch in range(1, epochs + 1):\n",
    "        print(\"Epoch {}\".format(epoch))\n",
    "        avg_g_loss = 0\n",
    "        avg_d_loss = 0\n",
    "        pbar = notebook.tqdm(train_dataloader, total=len(train_dataloader))\n",
    "        i = 0\n",
    "        for data in pbar:\n",
    "            i += 1\n",
    "            real_images = data[0].to(device)\n",
    "\n",
    "            ### Train Generator ###\n",
    "            generator_optim.zero_grad()\n",
    "\n",
    "            latent_input = torch.randn((batch_size, 1, *latent_shape)).to(device)\n",
    "            fake_images = generator(latent_input)\n",
    "\n",
    "            fake_res = discriminator(fake_images)\n",
    "\n",
    "            # the generator wants its fakes classified as real (target = 1)\n",
    "            generator_loss = adversarial_loss(fake_res, torch.ones_like(fake_res))\n",
    "            generator_loss.backward()\n",
    "            generator_optim.step()\n",
    "\n",
    "            ### Train Discriminator ###\n",
    "            discriminator_optim.zero_grad()\n",
    "\n",
    "            real_res = discriminator(real_images)\n",
    "            # detach so this step sends no gradient back into the generator\n",
    "            fake_res = discriminator(fake_images.detach())\n",
    "\n",
    "            discriminator_real_loss = adversarial_loss(real_res, torch.ones_like(real_res))\n",
    "            discriminator_fake_loss = adversarial_loss(fake_res, torch.zeros_like(fake_res))\n",
    "            discriminator_loss = (discriminator_real_loss + discriminator_fake_loss) / 2\n",
    "            discriminator_loss.backward()\n",
    "            discriminator_optim.step()\n",
    "\n",
    "            avg_g_loss += generator_loss.item()\n",
    "            avg_d_loss += discriminator_loss.item()\n",
    "            pbar.set_postfix({\"G_loss\": generator_loss.item(), \"D_loss\": discriminator_loss.item()})\n",
    "        print(\"Avg G_loss: {:.4f}, Avg D_loss: {:.4f}\".format(avg_g_loss / i, avg_d_loss / i))"
   ]
  },
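  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch (added): the label convention above in miniature.\n",
    "# BCE(p, 1) = -log(p) punishes low scores on the \"real\" target, while\n",
    "# BCE(p, 0) = -log(1 - p) punishes high scores on the \"fake\" target.\n",
    "bce = torch.nn.BCELoss(reduction=\"none\")\n",
    "p = torch.tensor([0.9, 0.1])\n",
    "print(bce(p, torch.ones_like(p)))   # ~[0.105, 2.303]: 0.9 looks real, 0.1 does not\n",
    "print(bce(p, torch.zeros_like(p)))  # ~[2.303, 0.105]: the reverse for the fake target"
   ]
  },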
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train our generator and discriminator.\n",
    "# Note: don't expect both losses to go down at the same time. The two models are\n",
    "# competing against each other, so at any point one may be winning the game.\n",
    "train(generator=generator, discriminator=discriminator, generator_optim=gen_optim, discriminator_optim=disc_optim)"
   ]
  },
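  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Added sketch: render a 4x4 grid of generated digits; the cell below shows a\n",
    "# single sample instead.\n",
    "with torch.no_grad():\n",
    "    z = torch.randn((16, 1, *latent_shape)).to(device)\n",
    "    samples = generator(z).cpu()\n",
    "grid = torchvision.utils.make_grid(samples, nrow=4, normalize=True)\n",
    "plt.figure(figsize=(6, 6))\n",
    "plt.imshow(grid.permute(1, 2, 0).numpy())\n",
    "plt.axis(\"off\")\n",
    "plt.show()"
   ]
  },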
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test it out!\n",
    "latent_input = torch.randn((batch_size, 1, *latent_shape))\n",
    "test = generator(latent_input.to(device))\n",
    "plt.imshow(test[0].reshape(28, 28).cpu().detach().numpy(), cmap=\"gray\")"
   ]
  }
 ]
}