Skip to content

Commit dd8f9e9

Browse files
committed
changed models to ResNet-based encoder-decoder generator and a fully-convolutional (PathGAN) discriminator
1 parent cb0c20b commit dd8f9e9

File tree

1 file changed

+194
-58
lines changed

1 file changed

+194
-58
lines changed

src/aging_gan/model.py

Lines changed: 194 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,75 +1,211 @@
11
import torch.nn as nn
2-
import torch.nn.utils as nn_utils
3-
import segmentation_models_pytorch as smp
2+
import torch.nn.functional as F
43

4+
# import torch.nn.utils as nn_utils
5+
# import segmentation_models_pytorch as smp
56

6-
# Discriminator: PatchGAN 70x70
7-
class PatchDiscriminator(nn.Module):
8-
def __init__(self, in_channels=3, ndf=48):
7+
8+
class ResidualBlock(nn.Module):
9+
def __init__(self, in_features):
910
super().__init__()
10-
layers = [
11-
nn_utils.spectral_norm(
11+
12+
conv_block = [
13+
nn.ReflectionPad2d(1), # (B, C, H+2, W+2)
14+
nn.Conv2d(in_features, in_features, 3), # (B, C, H, W)
15+
nn.BatchNorm2d(in_features), # (B, C, H, W)
16+
nn.ReLU(), # (B, C, H, W)
17+
nn.ReflectionPad2d(1), # (B, C, H+2, W+2)
18+
nn.Conv2d(in_features, in_features, 3), # (B, C, H, W)
19+
nn.BatchNorm2d(in_features),
20+
] # (B, C, H, W)
21+
22+
self.conv_block = nn.Sequential(*conv_block)
23+
24+
def forward(self, x):
25+
return x + self.conv_block(x) # skip connection
26+
27+
28+
class Generator(nn.Module):
29+
def __init__(self, ngf, n_residual_blocks=9):
30+
super().__init__()
31+
32+
# Initial convlution block
33+
model = [
34+
nn.ReflectionPad2d(
35+
3
36+
), # (B, 3, H+6, W+6), applies 2D "reflection" padding of 3 pixels on all four sides of image
37+
nn.Conv2d(
38+
3, ngf, 7
39+
), # (B, ngf, H, W), 3 in_channels, ngf out_channels, kernel size 7 (keeps same image size)
40+
nn.BatchNorm2d(
41+
ngf
42+
), # (B, ngf, H, W), normalized for each ngf across all B, H, W
43+
nn.ReLU(),
44+
] # (B, ngf, H, W)
45+
46+
# Downsampling
47+
in_features = ngf # number of generator filters
48+
out_features = in_features * 2
49+
for _ in range(2):
50+
model += [
1251
nn.Conv2d(
13-
in_channels=in_channels,
14-
out_channels=ndf,
15-
kernel_size=4,
16-
stride=2,
17-
padding=1,
18-
)
19-
),
20-
nn.LeakyReLU(0.2),
21-
]
22-
nf = ndf
23-
for i in range(3):
24-
stride = 2 if i < 2 else 1
25-
layers += [
26-
nn_utils.spectral_norm(nn.Conv2d(nf, nf * 2, 4, stride, 1)),
27-
nn.InstanceNorm2d(nf * 2, affine=True),
28-
nn.LeakyReLU(0.2),
29-
]
30-
nf *= 2
31-
layers += [nn_utils.spectral_norm(nn.Conv2d(nf, 1, 4, 1, 1))]
32-
self.model = nn.Sequential(*layers)
52+
in_features, out_features, 3, stride=2, padding=1
53+
), # (B, in_features*2, H//2, W//2), doubles number of channels and reduces H, W by half
54+
nn.BatchNorm2d(out_features), # (B, in_features*2, H//2, W//2)
55+
nn.ReLU(),
56+
] # (B, in_features*2, H//2, W//2)
57+
in_features = out_features
58+
out_features = in_features * 2
59+
60+
# Residual blocks
61+
for _ in range(n_residual_blocks):
62+
model += [
63+
ResidualBlock(in_features)
64+
] # (B, in_features, H, W), returns same size as input
65+
66+
# Upsampling
67+
out_features = in_features // 2
68+
for _ in range(2):
69+
model += [
70+
nn.ConvTranspose2d(
71+
in_features, out_features, 3, stride=2, padding=1, output_padding=1
72+
), # (B, in_features//2, H*2, W*2), upsamples to twice the H, W with half the channels
73+
nn.BatchNorm2d(out_features), # (B, in_features//2, H*2, W*2)
74+
nn.ReLU(),
75+
] # (B, in_features//2, H*2, W*2)
76+
in_features = out_features
77+
out_features = in_features // 2
78+
79+
# Output layer
80+
model += [
81+
nn.ReflectionPad2d(3), # (B, in_features, H+6, W+6)
82+
nn.Conv2d(ngf, 3, 7), # (B, 3, H, W)
83+
nn.Tanh(),
84+
] # (B, 3, H, W), passed tanh activation
85+
86+
self.model = nn.Sequential(*model)
3387

3488
def forward(self, x):
3589
return self.model(x)
3690

3791

38-
# Freeze encoder of model so that model can learn "aging" during the first epoch
39-
def freeze_encoders(G, F):
40-
for param in G.encoder.parameters():
41-
param.requires_grad = False
42-
for param in F.encoder.parameters():
43-
param.requires_grad = False
92+
class Discriminator(nn.Module):
93+
def __init__(self, ndf):
94+
super().__init__()
95+
96+
model = [
97+
nn.Conv2d(
98+
3, ndf, 4, stride=2, padding=1
99+
), # (B, ndf, H//2, W//2), channel from 3 -> ndf
100+
nn.LeakyReLU(0.2, inplace=True),
101+
] # (B, ndf, H//2, W//2)
102+
103+
model += [
104+
nn.Conv2d(ndf, ndf * 2, 4, stride=2, padding=1), # (B, ndf * 2, H//4, W//4)
105+
nn.BatchNorm2d(ndf * 2),
106+
nn.LeakyReLU(0.2, inplace=True),
107+
]
44108

109+
model += [
110+
nn.Conv2d(
111+
ndf * 2, ndf * 4, 4, stride=2, padding=1
112+
), # (B, ndf * 4, H//8, W//8)
113+
nn.InstanceNorm2d(ndf * 4),
114+
nn.LeakyReLU(0.2, inplace=True),
115+
]
116+
117+
model += [
118+
nn.Conv2d(ndf * 4, ndf * 8, 4, padding=1), # (B, ndf * 8, H//8-1, W//8-1)
119+
nn.InstanceNorm2d(ndf * 8),
120+
nn.LeakyReLU(0.2, inplace=True),
121+
]
122+
123+
# FCN classification layer
124+
model += [nn.Conv2d(ndf * 8, 1, 4, padding=1)] # (B, 1, H//8-2, W//8-2)
125+
126+
self.model = nn.Sequential(*model)
127+
128+
def forward(self, x):
129+
# x: (B, 3, H, W)
130+
x = self.model(x) # (B, 1, H//8-2, W//8-2)
131+
# Average pooling and flatten
132+
return F.avg_pool2d(x, x.size()[2:]).view(
133+
x.size()[0], -1
134+
) # global average -> (B, 1, 1, 1) -> flatten to (B, 1)
45135

46-
# Unfreeze encoders later
47-
def unfreeze_encoders(G, F):
48-
for param in G.encoder.parameters():
49-
param.requires_grad = True
50-
for param in F.encoder.parameters():
51-
param.requires_grad = True
136+
137+
# # Discriminator: PatchGAN 70x70
138+
# class PatchDiscriminator(nn.Module):
139+
# def __init__(self, in_channels=3, ndf=48):
140+
# super().__init__()
141+
# layers = [
142+
# nn_utils.spectral_norm(
143+
# nn.Conv2d(
144+
# in_channels=in_channels,
145+
# out_channels=ndf,
146+
# kernel_size=4,
147+
# stride=2,
148+
# padding=1,
149+
# )
150+
# ),
151+
# nn.LeakyReLU(0.2),
152+
# ]
153+
# nf = ndf
154+
# for i in range(3):
155+
# stride = 2 if i < 2 else 1
156+
# layers += [
157+
# nn_utils.spectral_norm(nn.Conv2d(nf, nf * 2, 4, stride, 1)),
158+
# nn.InstanceNorm2d(nf * 2, affine=True),
159+
# nn.LeakyReLU(0.2),
160+
# ]
161+
# nf *= 2
162+
# layers += [nn_utils.spectral_norm(nn.Conv2d(nf, 1, 4, 1, 1))]
163+
# self.model = nn.Sequential(*layers)
164+
165+
# def forward(self, x):
166+
# return self.model(x)
167+
168+
169+
# # Freeze encoder of model so that model can learn "aging" during the first epoch
170+
# def freeze_encoders(G, F):
171+
# for param in G.encoder.parameters():
172+
# param.requires_grad = False
173+
# for param in F.encoder.parameters():
174+
# param.requires_grad = False
175+
176+
177+
# # Unfreeze encoders later
178+
# def unfreeze_encoders(G, F):
179+
# for param in G.encoder.parameters():
180+
# param.requires_grad = True
181+
# for param in F.encoder.parameters():
182+
# param.requires_grad = True
52183

53184

54185
# Initialize and return the generators and discriminators used for training
55-
def initialize_models():
56-
# initialize the generators
57-
G = smp.Unet(
58-
encoder_name="resnet34",
59-
encoder_weights="imagenet", # preload low-level filters
60-
in_channels=3, # RGB input
61-
classes=3, # RGB output
62-
)
63-
64-
F = smp.Unet(
65-
encoder_name="resnet34",
66-
encoder_weights="imagenet", # preload low-level filters
67-
in_channels=3, # RGB input
68-
classes=3, # RGB output
69-
)
70-
71-
# initlize the discriminator
72-
DX = PatchDiscriminator()
73-
DY = PatchDiscriminator()
186+
def initialize_models(
187+
ngf: int = 32,
188+
ndf: int = 32,
189+
n_blocks: int = 9,
190+
):
191+
# G = smp.Unet(
192+
# encoder_name="resnet34",
193+
# encoder_weights="imagenet", # preload low-level filters
194+
# in_channels=3, # RGB input
195+
# classes=3, # RGB output
196+
# )
197+
198+
# F = smp.Unet(
199+
# encoder_name="resnet34",
200+
# encoder_weights="imagenet", # preload low-level filters
201+
# in_channels=3, # RGB input
202+
# classes=3, # RGB output
203+
# )
204+
205+
# initialize the generators and discriminators
206+
G = Generator(ngf, n_blocks)
207+
F = Generator(ngf, n_blocks)
208+
DX = Discriminator(ndf)
209+
DY = Discriminator(ndf)
74210

75211
return G, F, DX, DY

0 commit comments

Comments
 (0)