import torch.nn as nn
import torch.nn.functional as F

-# import torch.nn.utils as nn_utils
-# import segmentation_models_pytorch as smp
-

class ResidualBlock(nn.Module):
    """Simple residual block with two conv layers."""
@@ -138,56 +135,12 @@ def __init__(self, ndf: int) -> None:

    def forward(self, x: Tensor) -> Tensor:
        """Return discriminator logits for input ``x``."""
-        x = self.model(x)
-        return F.avg_pool2d(x, x.size()[2:]).view(x.size()[0], -1)
-
-
-# # Discriminator: PatchGAN 70x70
-# class PatchDiscriminator(nn.Module):
-#     def __init__(self, in_channels=3, ndf=48):
-#         super().__init__()
-#         layers = [
-#             nn_utils.spectral_norm(
-#                 nn.Conv2d(
-#                     in_channels=in_channels,
-#                     out_channels=ndf,
-#                     kernel_size=4,
-#                     stride=2,
-#                     padding=1,
-#                 )
-#             ),
-#             nn.LeakyReLU(0.2),
-#         ]
-#         nf = ndf
-#         for i in range(3):
-#             stride = 2 if i < 2 else 1
-#             layers += [
-#                 nn_utils.spectral_norm(nn.Conv2d(nf, nf * 2, 4, stride, 1)),
-#                 nn.InstanceNorm2d(nf * 2, affine=True),
-#                 nn.LeakyReLU(0.2),
-#             ]
-#             nf *= 2
-#         layers += [nn_utils.spectral_norm(nn.Conv2d(nf, 1, 4, 1, 1))]
-#         self.model = nn.Sequential(*layers)
-
-#     def forward(self, x):
-#         return self.model(x)
-
-
-# # Freeze encoder of model so that model can learn "aging" during the first epoch
-# def freeze_encoders(G, F):
-#     for param in G.encoder.parameters():
-#         param.requires_grad = False
-#     for param in F.encoder.parameters():
-#         param.requires_grad = False
-
-
-# # Unfreeze encoders later
-# def unfreeze_encoders(G, F):
-#     for param in G.encoder.parameters():
-#         param.requires_grad = True
-#     for param in F.encoder.parameters():
-#         param.requires_grad = True
+        # x: (B, 3, H, W)
+        x = self.model(x)  # (B, 1, H//8-2, W//8-2)
+        # Average pooling and flatten
+        return F.avg_pool2d(x, x.size()[2:]).view(
+            x.size()[0], -1
+        )  # global average -> (B, 1, 1, 1) -> flatten to (B, 1)

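The added comments in the hunk above annotate the shape flow through the discriminator head. A minimal standalone sketch of that pooling-and-flatten step, assuming a 256x256 input so the conv stack yields a (B, 1, 30, 30) patch map (256 // 8 - 2 = 30):

import torch
import torch.nn.functional as F

# Hypothetical patch logits from the discriminator's conv stack.
patch_logits = torch.randn(4, 1, 30, 30)

# avg_pool2d with the full spatial size as the kernel acts as global average
# pooling, collapsing the map to (B, 1, 1, 1).
pooled = F.avg_pool2d(patch_logits, patch_logits.size()[2:])

# Flatten to (B, 1): one real/fake logit per image.
logits = pooled.view(patch_logits.size(0), -1)
print(logits.shape)  # torch.Size([4, 1])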
# Initialize and return the generators and discriminators used for training
@@ -197,20 +150,6 @@ def initialize_models(
    n_blocks: int = 9,
) -> tuple[Generator, Generator, Discriminator, Discriminator]:
    """Instantiate generators and discriminators with default sizes."""
-    # G = smp.Unet(
-    #     encoder_name="resnet34",
-    #     encoder_weights="imagenet",  # preload low-level filters
-    #     in_channels=3,  # RGB input
-    #     classes=3,  # RGB output
-    # )
-
-    # F = smp.Unet(
-    #     encoder_name="resnet34",
-    #     encoder_weights="imagenet",  # preload low-level filters
-    #     in_channels=3,  # RGB input
-    #     classes=3,  # RGB output
-    # )
-
    # initialize the generators and discriminators
    G = Generator(ngf, n_blocks)
    F = Generator(ngf, n_blocks)
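A minimal usage sketch of initialize_models as declared above; the unpacking order, the variable names, and the assumption that the generators preserve input resolution are illustrative rather than taken from this file:

import torch

# Two generators and two discriminators, using the defaults from the signature.
G, F_gen, D_X, D_Y = initialize_models()

x = torch.randn(2, 3, 256, 256)   # dummy RGB batch (size assumed)
fake_y = G(x)                     # translated images (assumed same size as x)
logits = D_Y(fake_y)              # discriminator logits, shape (2, 1)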