Skip to content

Commit c646ef6

Browse files
checked GAN code
1 parent b6985ec commit c646ef6

File tree

14 files changed

+225
-270
lines changed

14 files changed

+225
-270
lines changed

ML/Pytorch/GANs/1. SimpleGAN/fc_gan.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,12 @@
1+
"""
2+
Simple GAN using fully connected layers
3+
4+
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
5+
* 2020-11-01: Initial coding
6+
* 2022-12-20: Small revision of code, checked that it works with latest PyTorch version
7+
"""
8+
9+
110
import torch
211
import torch.nn as nn
312
import torch.optim as optim
@@ -48,7 +57,10 @@ def forward(self, x):
4857
gen = Generator(z_dim, image_dim).to(device)
4958
fixed_noise = torch.randn((batch_size, z_dim)).to(device)
5059
transforms = transforms.Compose(
51-
[transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,)),]
60+
[
61+
transforms.ToTensor(),
62+
transforms.Normalize((0.5,), (0.5,)),
63+
]
5264
)
5365

5466
dataset = datasets.MNIST(root="dataset/", transform=transforms, download=True)
@@ -104,4 +116,4 @@ def forward(self, x):
104116
writer_real.add_image(
105117
"Mnist Real Images", img_grid_real, global_step=step
106118
)
107-
step += 1
119+
step += 1

ML/Pytorch/GANs/2. DCGAN/model.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
"""
22
Discriminator and Generator implementation from DCGAN paper
3+
4+
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
5+
* 2020-11-01: Initial coding
6+
* 2022-12-20: Small revision of code, checked that it works with latest PyTorch version
37
"""
48

59
import torch
@@ -11,9 +15,7 @@ def __init__(self, channels_img, features_d):
1115
super(Discriminator, self).__init__()
1216
self.disc = nn.Sequential(
1317
# input: N x channels_img x 64 x 64
14-
nn.Conv2d(
15-
channels_img, features_d, kernel_size=4, stride=2, padding=1
16-
),
18+
nn.Conv2d(channels_img, features_d, kernel_size=4, stride=2, padding=1),
1719
nn.LeakyReLU(0.2),
1820
# _block(in_channels, out_channels, kernel_size, stride, padding)
1921
self._block(features_d, features_d * 2, 4, 2, 1),
@@ -34,7 +36,7 @@ def _block(self, in_channels, out_channels, kernel_size, stride, padding):
3436
padding,
3537
bias=False,
3638
),
37-
#nn.BatchNorm2d(out_channels),
39+
# nn.BatchNorm2d(out_channels),
3840
nn.LeakyReLU(0.2),
3941
)
4042

@@ -68,7 +70,7 @@ def _block(self, in_channels, out_channels, kernel_size, stride, padding):
6870
padding,
6971
bias=False,
7072
),
71-
#nn.BatchNorm2d(out_channels),
73+
# nn.BatchNorm2d(out_channels),
7274
nn.ReLU(),
7375
)
7476

@@ -82,6 +84,7 @@ def initialize_weights(model):
8284
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
8385
nn.init.normal_(m.weight.data, 0.0, 0.02)
8486

87+
8588
def test():
8689
N, in_channels, H, W = 8, 3, 64, 64
8790
noise_dim = 100
@@ -91,6 +94,8 @@ def test():
9194
gen = Generator(noise_dim, in_channels, 8)
9295
z = torch.randn((N, noise_dim, 1, 1))
9396
assert gen(z).shape == (N, in_channels, H, W), "Generator test failed"
97+
print("Success, tests passed!")
9498

9599

96-
# test()
100+
if __name__ == "__main__":
101+
test()

ML/Pytorch/GANs/2. DCGAN/train.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
"""
22
Training of DCGAN network on MNIST dataset with Discriminator
33
and Generator imported from models.py
4+
5+
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
6+
* 2020-11-01: Initial coding
7+
* 2022-12-20: Small revision of code, checked that it works with latest PyTorch version
48
"""
59

610
import torch
@@ -35,11 +39,12 @@
3539
)
3640

3741
# If you train on MNIST, remember to set channels_img to 1
38-
dataset = datasets.MNIST(root="dataset/", train=True, transform=transforms,
39-
download=True)
42+
dataset = datasets.MNIST(
43+
root="dataset/", train=True, transform=transforms, download=True
44+
)
4045

4146
# comment mnist above and uncomment below if train on CelebA
42-
#dataset = datasets.ImageFolder(root="celeb_dataset", transform=transforms)
47+
# dataset = datasets.ImageFolder(root="celeb_dataset", transform=transforms)
4348
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
4449
gen = Generator(NOISE_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
4550
disc = Discriminator(CHANNELS_IMG, FEATURES_DISC).to(device)
@@ -92,14 +97,10 @@
9297
with torch.no_grad():
9398
fake = gen(fixed_noise)
9499
# take out (up to) 32 examples
95-
img_grid_real = torchvision.utils.make_grid(
96-
real[:32], normalize=True
97-
)
98-
img_grid_fake = torchvision.utils.make_grid(
99-
fake[:32], normalize=True
100-
)
100+
img_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)
101+
img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)
101102

102103
writer_real.add_image("Real", img_grid_real, global_step=step)
103104
writer_fake.add_image("Fake", img_grid_fake, global_step=step)
104105

105-
step += 1
106+
step += 1

ML/Pytorch/GANs/3. WGAN/model.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22
Discriminator and Generator implementation from DCGAN paper,
33
with removed Sigmoid() as output from Discriminator (and therefor
44
it should be called critic)
5+
6+
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
7+
* 2020-11-01: Initial coding
8+
* 2022-12-20: Small revision of code, checked that it works with latest PyTorch version
59
"""
610

711
import torch
@@ -93,6 +97,7 @@ def test():
9397
gen = Generator(noise_dim, in_channels, 8)
9498
z = torch.randn((N, noise_dim, 1, 1))
9599
assert gen(z).shape == (N, in_channels, H, W), "Generator test failed"
100+
print("Success, tests passed!")
96101

97-
98-
# test()
102+
if __name__ == "__main__":
103+
test()

ML/Pytorch/GANs/3. WGAN/train.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
"""
22
Training of DCGAN network with WGAN loss
3+
4+
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
5+
* 2020-11-01: Initial coding
6+
* 2022-12-20: Small revision of code, checked that it works with latest PyTorch version
37
"""
48

59
import torch
@@ -9,6 +13,7 @@
913
import torchvision.datasets as datasets
1014
import torchvision.transforms as transforms
1115
from torch.utils.data import DataLoader
16+
from tqdm import tqdm
1217
from torch.utils.tensorboard import SummaryWriter
1318
from model import Discriminator, Generator, initialize_weights
1419

@@ -61,7 +66,7 @@
6166

6267
for epoch in range(NUM_EPOCHS):
6368
# Target labels not needed! <3 unsupervised
64-
for batch_idx, (data, _) in enumerate(loader):
69+
for batch_idx, (data, _) in enumerate(tqdm(loader)):
6570
data = data.to(device)
6671
cur_batch_size = data.shape[0]
6772

@@ -111,4 +116,4 @@
111116

112117
step += 1
113118
gen.train()
114-
critic.train()
119+
critic.train()

ML/Pytorch/GANs/4. WGAN-GP/model.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
"""
22
Discriminator and Generator implementation from DCGAN paper
3+
4+
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
5+
* 2020-11-01: Initial coding
6+
* 2022-12-20: Small revision of code, checked that it works with latest PyTorch version
37
"""
48

59
import torch
@@ -24,7 +28,12 @@ def __init__(self, channels_img, features_d):
2428
def _block(self, in_channels, out_channels, kernel_size, stride, padding):
2529
return nn.Sequential(
2630
nn.Conv2d(
27-
in_channels, out_channels, kernel_size, stride, padding, bias=False,
31+
in_channels,
32+
out_channels,
33+
kernel_size,
34+
stride,
35+
padding,
36+
bias=False,
2837
),
2938
nn.InstanceNorm2d(out_channels, affine=True),
3039
nn.LeakyReLU(0.2),
@@ -53,7 +62,12 @@ def __init__(self, channels_noise, channels_img, features_g):
5362
def _block(self, in_channels, out_channels, kernel_size, stride, padding):
5463
return nn.Sequential(
5564
nn.ConvTranspose2d(
56-
in_channels, out_channels, kernel_size, stride, padding, bias=False,
65+
in_channels,
66+
out_channels,
67+
kernel_size,
68+
stride,
69+
padding,
70+
bias=False,
5771
),
5872
nn.BatchNorm2d(out_channels),
5973
nn.ReLU(),

ML/Pytorch/GANs/4. WGAN-GP/train.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
"""
22
Training of WGAN-GP
3+
4+
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
5+
* 2020-11-01: Initial coding
6+
* 2022-12-20: Small revision of code, checked that it works with latest PyTorch version
37
"""
48

59
import torch
@@ -10,6 +14,7 @@
1014
import torchvision.transforms as transforms
1115
from torch.utils.data import DataLoader
1216
from torch.utils.tensorboard import SummaryWriter
17+
from tqdm import tqdm
1318
from utils import gradient_penalty, save_checkpoint, load_checkpoint
1419
from model import Discriminator, Generator, initialize_weights
1520

@@ -31,13 +36,14 @@
3136
transforms.Resize(IMAGE_SIZE),
3237
transforms.ToTensor(),
3338
transforms.Normalize(
34-
[0.5 for _ in range(CHANNELS_IMG)], [0.5 for _ in range(CHANNELS_IMG)]),
39+
[0.5 for _ in range(CHANNELS_IMG)], [0.5 for _ in range(CHANNELS_IMG)]
40+
),
3541
]
3642
)
3743

3844
dataset = datasets.MNIST(root="dataset/", transform=transforms, download=True)
3945
# comment mnist above and uncomment below for training on CelebA
40-
#dataset = datasets.ImageFolder(root="celeb_dataset", transform=transforms)
46+
# dataset = datasets.ImageFolder(root="celeb_dataset", transform=transforms)
4147
loader = DataLoader(
4248
dataset,
4349
batch_size=BATCH_SIZE,
@@ -66,7 +72,7 @@
6672

6773
for epoch in range(NUM_EPOCHS):
6874
# Target labels not needed! <3 unsupervised
69-
for batch_idx, (real, _) in enumerate(loader):
75+
for batch_idx, (real, _) in enumerate(tqdm(loader)):
7076
real = real.to(device)
7177
cur_batch_size = real.shape[0]
7278

@@ -108,4 +114,4 @@
108114
writer_real.add_image("Real", img_grid_real, global_step=step)
109115
writer_fake.add_image("Fake", img_grid_fake, global_step=step)
110116

111-
step += 1
117+
step += 1

ML/Pytorch/GANs/CycleGAN/config.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
LAMBDA_CYCLE = 10
1212
NUM_WORKERS = 4
1313
NUM_EPOCHS = 10
14-
LOAD_MODEL = True
14+
LOAD_MODEL = False
1515
SAVE_MODEL = True
1616
CHECKPOINT_GEN_H = "genh.pth.tar"
1717
CHECKPOINT_GEN_Z = "genz.pth.tar"
@@ -24,6 +24,6 @@
2424
A.HorizontalFlip(p=0.5),
2525
A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255),
2626
ToTensorV2(),
27-
],
27+
],
2828
additional_targets={"image0": "image"},
29-
)
29+
)

ML/Pytorch/GANs/CycleGAN/discriminator_model.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,28 @@
1+
"""
2+
Discriminator model for CycleGAN
3+
4+
Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
5+
* 2020-11-05: Initial coding
6+
* 2022-12-21: Small revision of code, checked that it works with latest PyTorch version
7+
"""
8+
19
import torch
210
import torch.nn as nn
311

12+
413
class Block(nn.Module):
514
def __init__(self, in_channels, out_channels, stride):
615
super().__init__()
716
self.conv = nn.Sequential(
8-
nn.Conv2d(in_channels, out_channels, 4, stride, 1, bias=True, padding_mode="reflect"),
17+
nn.Conv2d(
18+
in_channels,
19+
out_channels,
20+
4,
21+
stride,
22+
1,
23+
bias=True,
24+
padding_mode="reflect",
25+
),
926
nn.InstanceNorm2d(out_channels),
1027
nn.LeakyReLU(0.2, inplace=True),
1128
)
@@ -32,15 +49,27 @@ def __init__(self, in_channels=3, features=[64, 128, 256, 512]):
3249
layers = []
3350
in_channels = features[0]
3451
for feature in features[1:]:
35-
layers.append(Block(in_channels, feature, stride=1 if feature==features[-1] else 2))
52+
layers.append(
53+
Block(in_channels, feature, stride=1 if feature == features[-1] else 2)
54+
)
3655
in_channels = feature
37-
layers.append(nn.Conv2d(in_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect"))
56+
layers.append(
57+
nn.Conv2d(
58+
in_channels,
59+
1,
60+
kernel_size=4,
61+
stride=1,
62+
padding=1,
63+
padding_mode="reflect",
64+
)
65+
)
3866
self.model = nn.Sequential(*layers)
3967

4068
def forward(self, x):
4169
x = self.initial(x)
4270
return torch.sigmoid(self.model(x))
4371

72+
4473
def test():
4574
x = torch.randn((5, 3, 256, 256))
4675
model = Discriminator(in_channels=3)
@@ -50,4 +79,3 @@ def test():
5079

5180
if __name__ == "__main__":
5281
test()
53-

0 commit comments

Comments
 (0)