Commit 68ba879

removed commented unused code and edited annotations
1 parent 8210454 commit 68ba879

5 files changed: 7 additions & 98 deletions

README.md

Lines changed: 0 additions & 2 deletions
```diff
@@ -1,5 +1,3 @@
-added spectral norm to path gan.
-learning reduces to 0.1
 different lr for gen and disc
 
 
```
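The surviving README note, "different lr for gen and disc", refers to giving the generator and discriminator optimizers separate learning rates, a common GAN stabilization trick. A minimal sketch of that pattern, assuming the generators `G`, `F` and discriminators `DX`, `DY` from `src/aging_gan/model.py`; the rates and betas shown are illustrative, not the project's configured values:

```python
import itertools

import torch

# Hypothetical rates: the generators step faster than the discriminators.
opt_G = torch.optim.Adam(
    itertools.chain(G.parameters(), F.parameters()),
    lr=2e-4,
    betas=(0.5, 0.999),
)
opt_D = torch.optim.Adam(
    itertools.chain(DX.parameters(), DY.parameters()),
    lr=1e-4,
    betas=(0.5, 0.999),
)
```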

src/aging_gan/data.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -46,7 +46,7 @@ def __getitem__(self, idx: int) -> Tuple[Tensor, int]:
 
 def make_unpaired_loader(
     root: str,
-    split: str,  # "train" | "valid" | "test"
+    split: str,
     transform: T.Compose,
     batch_size: int = 4,
     num_workers: int = 1,
```
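The removed inline comment documented the accepted split names. A hypothetical call for reference; the transform pipeline and root path here are placeholders, not the project's actual configuration:

```python
import torchvision.transforms as T

from aging_gan.data import make_unpaired_loader

transform = T.Compose([T.Resize(256), T.ToTensor()])  # placeholder pipeline
train_loader = make_unpaired_loader(
    root="data",    # placeholder dataset root
    split="train",  # still one of "train" | "valid" | "test"
    transform=transform,
    batch_size=4,
    num_workers=1,
)
```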

src/aging_gan/model.py

Lines changed: 6 additions & 67 deletions
```diff
@@ -4,9 +4,6 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-# import torch.nn.utils as nn_utils
-# import segmentation_models_pytorch as smp
-
 
 class ResidualBlock(nn.Module):
     """Simple residual block with two conv layers."""
@@ -138,56 +135,12 @@ def __init__(self, ndf: int) -> None:
 
     def forward(self, x: Tensor) -> Tensor:
         """Return discriminator logits for input ``x``."""
-        x = self.model(x)
-        return F.avg_pool2d(x, x.size()[2:]).view(x.size()[0], -1)
-
-
-# # Discriminator: PatchGAN 70x70
-# class PatchDiscriminator(nn.Module):
-#     def __init__(self, in_channels=3, ndf=48):
-#         super().__init__()
-#         layers = [
-#             nn_utils.spectral_norm(
-#                 nn.Conv2d(
-#                     in_channels=in_channels,
-#                     out_channels=ndf,
-#                     kernel_size=4,
-#                     stride=2,
-#                     padding=1,
-#                 )
-#             ),
-#             nn.LeakyReLU(0.2),
-#         ]
-#         nf = ndf
-#         for i in range(3):
-#             stride = 2 if i < 2 else 1
-#             layers += [
-#                 nn_utils.spectral_norm(nn.Conv2d(nf, nf * 2, 4, stride, 1)),
-#                 nn.InstanceNorm2d(nf * 2, affine=True),
-#                 nn.LeakyReLU(0.2),
-#             ]
-#             nf *= 2
-#         layers += [nn_utils.spectral_norm(nn.Conv2d(nf, 1, 4, 1, 1))]
-#         self.model = nn.Sequential(*layers)
-
-#     def forward(self, x):
-#         return self.model(x)
-
-
-# # Freeze encoder of model so that model can learn "aging" during the first epoch
-# def freeze_encoders(G, F):
-#     for param in G.encoder.parameters():
-#         param.requires_grad = False
-#     for param in F.encoder.parameters():
-#         param.requires_grad = False
-
-
-# # Unfreeze encoders later
-# def unfreeze_encoders(G, F):
-#     for param in G.encoder.parameters():
-#         param.requires_grad = True
-#     for param in F.encoder.parameters():
-#         param.requires_grad = True
+        # x: (B, 3, H, W)
+        x = self.model(x)  # (B, 1, H//8-2, W//8-2)
+        # Average pooling and flatten
+        return F.avg_pool2d(x, x.size()[2:]).view(
+            x.size()[0], -1
+        )  # global average -> (B, 1, 1, 1) -> flatten to (B, 1)
 
 
 # Initialize and return the generators and discriminators used for training
```
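The annotations added to `forward` describe a global-average-pool head: per-patch logits are averaged into a single logit per image. A standalone check of that shape flow, with arbitrary example sizes:

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 1, 30, 30)           # stand-in for patch logits (B, 1, h, w)
pooled = F.avg_pool2d(x, x.size()[2:])  # global average -> (B, 1, 1, 1)
flat = pooled.view(x.size(0), -1)       # flatten -> (B, 1)
print(flat.shape)                       # torch.Size([2, 1])
```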
```diff
@@ -197,20 +150,6 @@ def initialize_models(
     n_blocks: int = 9,
 ) -> tuple[Generator, Generator, Discriminator, Discriminator]:
     """Instantiate generators and discriminators with default sizes."""
-    # G = smp.Unet(
-    #     encoder_name="resnet34",
-    #     encoder_weights="imagenet",  # preload low-level filters
-    #     in_channels=3,  # RGB input
-    #     classes=3,  # RGB output
-    # )
-
-    # F = smp.Unet(
-    #     encoder_name="resnet34",
-    #     encoder_weights="imagenet",  # preload low-level filters
-    #     in_channels=3,  # RGB input
-    #     classes=3,  # RGB output
-    # )
-
     # initialize the generators and discriminators
     G = Generator(ngf, n_blocks)
     F = Generator(ngf, n_blocks)
```
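With the commented-out `smp.Unet` variants gone, `initialize_models` only builds the two ResNet-style generators and two discriminators. A minimal usage sketch; the import path assumes the `src/aging_gan` package layout, and the `ngf` default is not visible in this hunk:

```python
from aging_gan.model import initialize_models

# G maps X -> Y, F maps Y -> X; DX and DY discriminate the two domains.
G, F, DX, DY = initialize_models()  # n_blocks defaults to 9
```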

src/aging_gan/train.py

Lines changed: 0 additions & 14 deletions
```diff
@@ -554,13 +554,6 @@ def main() -> None:
     # ---------- Models, Optimizers, Loss Functions, Schedulers Initialization ----------
     # Initialize the generators (G, F) and discriminators (DX, DY)
     G, F, DX, DY = initialize_models()
-    # Freeze generator encoderes for training during early epochs
-    # logger.info("Parameters of generator G:")
-    # logger.info(print_trainable_parameters(G))
-    # logger.info("Freezing encoders of generators...")
-    # freeze_encoders(G, F)
-    # logger.info("Parameters of generator G after freezing:")
-    # logger.info(print_trainable_parameters(G))
     # Initialize optimizers
     (
         opt_G,
@@ -613,13 +606,6 @@ def main() -> None:
     best_fid = float("inf")  # keep track of the best FID score for each epoch
     for epoch in range(1, cfg.num_train_epochs + 1):
         logger.info(f"\nEPOCH {epoch}")
-        # after 1 full epoch, unfreeze
-        # if epoch == 2:
-        #     logger.info("Unfreezing encoders of generators...")
-        #     unfreeze_encoders(G, F)
-        #     logger.info("Parameters of generator G after unfreezing:")
-        #     logger.info(print_trainable_parameters(G))
-
         val_metrics = perform_epoch(
             cfg,
             train_loader,
```
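Both hunks delete the call sites of the freeze/unfreeze experiment whose definitions were removed from `model.py` above. For reference, the schedule those helpers implemented, as a hedged sketch; it assumes the generators expose an `.encoder` submodule, which the current `Generator` may not:

```python
def set_encoder_trainable(*models, trainable: bool) -> None:
    """Toggle requires_grad on each model's encoder parameters."""
    for model in models:
        for param in model.encoder.parameters():
            param.requires_grad = trainable


set_encoder_trainable(G, F, trainable=False)  # epoch 1: adapt decoders only
# ... run one full epoch ...
set_encoder_trainable(G, F, trainable=True)   # epoch 2 onward: train everything
```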

src/aging_gan/utils.py

Lines changed: 0 additions & 14 deletions
```diff
@@ -39,20 +39,6 @@ def load_environ_vars(wandb_project: str = "aging-gan") -> None:
     logger.info(f"W&B project set to '{wandb_project}'")
 
 
-# def print_trainable_parameters(model) -> str:
-#     """
-#     Compute and return a summary of trainable vs. total parameters in a model.
-#     """
-#     trainable_params = 0
-#     all_param = 0
-#     for _, param in model.named_parameters():
-#         all_param += param.numel()
-#         if param.requires_grad:
-#             trainable_params += param.numel()
-
-#     return f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param:.2f}"
-
-
 def save_checkpoint(
     epoch,
     G,
```
