
Commit bf0a41b

Author: Ubuntu (committed)
Commit message: returned to og model and removed compile
1 parent 3e875c0

2 files changed: 31 additions & 124 deletions

src/aging_gan/model.py

Lines changed: 13 additions & 99 deletions
@@ -1,89 +1,6 @@
 import torch.nn as nn
 import torch.nn.utils as nn_utils
 import segmentation_models_pytorch as smp
-import torch.nn.functional as F
-
-# ------------------------------------------------------------
-# 9‑residual‑block ResNet generator (CycleGAN, 256×256)
-# ------------------------------------------------------------
-class ResnetBlock(nn.Module):
-    def __init__(self, channels, padding_type="reflect"):
-        super().__init__()
-        pad = nn.ReflectionPad2d if padding_type == "reflect" else nn.ZeroPad2d
-
-        self.block = nn.Sequential(
-            pad(1),
-            nn.Conv2d(channels, channels, 3, bias=False),
-            nn.InstanceNorm2d(channels, affine=True),
-            nn.ReLU(),
-            nn.Dropout(0.5),
-            pad(1),
-            nn.Conv2d(channels, channels, 3, bias=False),
-            nn.InstanceNorm2d(channels, affine=True),
-        )
-
-    def forward(self, x):
-        return x + self.block(x)  # residual add
-
-
-class ResnetGenerator(nn.Module):
-    def __init__(self, in_c=3, out_c=3, n_blocks=9, ngf=64):
-        super().__init__()
-        assert n_blocks >= 1
-
-        layers = [
-            nn.ReflectionPad2d(3),
-            nn.Conv2d(in_c, ngf, 7, bias=False),
-            nn.InstanceNorm2d(ngf, affine=True),
-            nn.ReLU(),
-        ]
-
-        # downsample twice: 256→128→64 spatial, 64→128→256 channels
-        mult = 1
-        for _ in range(2):
-            layers += [
-                nn.Conv2d(ngf * mult, ngf * mult * 2, 3, 2, 1, bias=False),
-                nn.InstanceNorm2d(ngf * mult * 2, affine=True),
-                nn.ReLU(),
-            ]
-            mult *= 2  # 1->2->4
-
-        # residual blocks
-        layers += [ResnetBlock(ngf * mult) for _ in range(n_blocks)]
-
-        # upsample back to 256×256
-        for _ in range(2):
-            layers += [
-                nn.ConvTranspose2d(
-                    ngf * mult, ngf * mult // 2,
-                    3, 2, 1, output_padding=1, bias=False
-                ),
-                nn.InstanceNorm2d(ngf * mult // 2, affine=True),
-                nn.ReLU(),
-            ]
-            mult //= 2  # 4->2->1
-
-        layers += [
-            nn.ReflectionPad2d(3),
-            nn.Conv2d(ngf, out_c, 7),  # bias=True is fine here
-            nn.Tanh(),
-        ]
-        self.model = nn.Sequential(*layers)
-
-        # weight init (Conv / ConvT)
-        for m in self.modules():
-            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
-                nn.init.normal_(m.weight, 0.0, 0.02)
-        # InstanceNorm affine params
-        for m in self.modules():
-            if isinstance(m, nn.InstanceNorm2d):
-                nn.init.constant_(m.weight, 1.0)
-                nn.init.constant_(m.bias, 0.0)
-
-    def forward(self, x):
-        return self.model(x)
-
-

 # Discriminator: PatchGAN 70x70
 class PatchDiscriminator(nn.Module):
@@ -138,22 +55,19 @@ def unfreeze_encoders(G, F):
 # Initialize and return the generators and discriminators used for training
 def initialize_models():
     # initialize the generators
-    # G = smp.Unet(
-    #     encoder_name="resnet34",
-    #     encoder_weights="imagenet", # preload low-level filters
-    #     in_channels=3, # RGB input
-    #     classes=3, # RGB output
-    # )
-
-    # F = smp.Unet(
-    #     encoder_name="resnet34",
-    #     encoder_weights="imagenet", # preload low-level filters
-    #     in_channels=3, # RGB input
-    #     classes=3, # RGB output
-    # )
-
-    G = ResnetGenerator()
-    F = ResnetGenerator()
+    G = smp.Unet(
+        encoder_name="resnet34",
+        encoder_weights="imagenet", # preload low-level filters
+        in_channels=3, # RGB input
+        classes=3, # RGB output
+    )
+
+    F = smp.Unet(
+        encoder_name="resnet34",
+        encoder_weights="imagenet", # preload low-level filters
+        in_channels=3, # RGB input
+        classes=3, # RGB output
+    )

     # initlize the discriminator
     DX = PatchDiscriminator()
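
For reference, a quick sanity check of the restored U-Net generators. This is a sketch, not part of the commit; the 256×256 input size is an assumption carried over from the removed ResnetGenerator. Note also that smp.Unet returns raw (unbounded) outputs by default, whereas the deleted generator ended in Tanh, so any [-1, 1] squashing would have to happen elsewhere in the pipeline.

import torch
import segmentation_models_pytorch as smp

# Same constructor arguments as the added lines above
G = smp.Unet(
    encoder_name="resnet34",
    encoder_weights="imagenet",
    in_channels=3,
    classes=3,
)

x = torch.randn(1, 3, 256, 256)  # assumed CycleGAN-style 256x256 RGB input
with torch.no_grad():
    y = G(x)
print(y.shape)  # torch.Size([1, 3, 256, 256]) -- spatial size is preserved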

src/aging_gan/train.py

Lines changed: 18 additions & 25 deletions
@@ -15,13 +15,13 @@
 from aging_gan.utils import (
     set_seed,
     load_environ_vars,
-    # print_trainable_parameters,
+    print_trainable_parameters,
     save_checkpoint,
     generate_and_save_samples,
     get_device,
 )
 from aging_gan.data import prepare_dataset
-from aging_gan.model import initialize_models, ResnetGenerator # , freeze_encoders, unfreeze_encoders
+from aging_gan.model import initialize_models, freeze_encoders, unfreeze_encoders
 from aging_gan.utils import archive_and_terminate

 logger = logging.getLogger(__name__)
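
The re-enabled print_trainable_parameters helper is defined in aging_gan.utils and not shown in this diff; a rough sketch of what such a helper usually does (hypothetical, the project's actual implementation may differ):

def print_trainable_parameters(model) -> str:
    # Hypothetical sketch: count parameters that currently require gradients,
    # which is what makes it useful around freeze_encoders()/unfreeze_encoders().
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    return f"trainable params: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)"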
@@ -45,18 +45,18 @@ def parse_args() -> argparse.Namespace:
         help="Initial learning rate for discriminators.",
     )
     p.add_argument(
-        "--num_train_epochs", type=int, default=80, help="Number of training epochs."
+        "--num_train_epochs", type=int, default=25, help="Number of training epochs."
     )
     p.add_argument(
         "--train_batch_size",
         type=int,
-        default=8,
+        default=16,
         help="Batch size per device during training.",
     )
     p.add_argument(
         "--eval_batch_size",
         type=int,
-        default=16,
+        default=32,
         help="Batch size per device during evaluation.",
     )

@@ -135,7 +135,7 @@ def initialize_optimizers(cfg, G, F, DX, DY):
     return opt_G, opt_F, opt_DX, opt_DY


-def initialize_loss_functions(lambda_cyc_value: int = 10.0, lambda_id_value: int = 5.0):
+def initialize_loss_functions(lambda_cyc_value: int = 2.0, lambda_id_value: int = 0.05):
     mse = nn.MSELoss()
     l1 = nn.L1Loss()
     lambda_cyc = lambda_cyc_value
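
The commit lowers the default loss weights from lambda_cyc=10.0, lambda_id=5.0 to 2.0 and 0.05. The rest of the function and the loss wiring are not shown here; the sketch below only illustrates the standard CycleGAN way such weights typically enter one generator's objective (an assumption about how train.py uses them, not the repository's actual code):

import torch
import torch.nn as nn

mse, l1 = nn.MSELoss(), nn.L1Loss()
lambda_cyc, lambda_id = 2.0, 0.05  # new defaults introduced by this commit

def loss_G_one_direction(real_x, real_y, rec_x, idt_y, dy_fake_out):
    adv = mse(dy_fake_out, torch.ones_like(dy_fake_out))  # LSGAN: fool D_Y on G(x)
    cyc = lambda_cyc * l1(rec_x, real_x)                   # F(G(x)) should reconstruct x
    idt = lambda_id * l1(idt_y, real_y)                    # G(y) should leave y unchanged
    return adv + cyc + idt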
@@ -498,19 +498,12 @@ def main() -> None:
     # Initialize the generators (G, F) and discriminators (DX, DY)
     G, F, DX, DY = initialize_models()
     # Freeze generator encoderes for training during early epochs
-    # logger.info("Parameters of generator G:")
-    # logger.info(print_trainable_parameters(G))
-    # logger.info("Freezing encoders of generators...")
-    # freeze_encoders(G, F)
-    # logger.info("Parameters of generator G after freezing:")
-    # logger.info(print_trainable_parameters(G))
-
-    # Compile
-    logger.info("Models compiling...")
-    G = torch.compile(G, backend="aot_eager", fullgraph=False, dynamic=True)
-    F = torch.compile(F, backend="aot_eager", fullgraph=False, dynamic=True)
-    DX = torch.compile(DX, backend="aot_eager", fullgraph=False, dynamic=True)
-    DY = torch.compile(DY, backend="aot_eager", fullgraph=False, dynamic=True)
+    logger.info("Parameters of generator G:")
+    logger.info(print_trainable_parameters(G))
+    logger.info("Freezing encoders of generators...")
+    freeze_encoders(G, F)
+    logger.info("Parameters of generator G after freezing:")
+    logger.info(print_trainable_parameters(G))
     # Initialize optimizers
     (
         opt_G,
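
freeze_encoders and unfreeze_encoders live in aging_gan.model and are not part of this diff. Since the generators are now smp.Unet models (which expose their backbone as .encoder), here is a minimal sketch of what these helpers plausibly do (an assumption, not the repository's actual code):

def freeze_encoders(G, F):
    # Plausible sketch: stop gradients through the ImageNet-pretrained backbones
    # so only the decoders train during the warm-up epoch.
    for gen in (G, F):
        for p in gen.encoder.parameters():
            p.requires_grad = False


def unfreeze_encoders(G, F):
    # Plausible sketch: re-enable encoder training (epoch 2 onwards, per the loop below).
    for gen in (G, F):
        for p in gen.encoder.parameters():
            p.requires_grad = True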
@@ -561,12 +554,12 @@ def main() -> None:
     best_fid = float("inf") # keep track of the best FID score for each epoch
     for epoch in range(1, cfg.num_train_epochs + 1):
         logger.info(f"\nEPOCH {epoch}")
-        # # after 1 full epoch, unfreeze
-        # if epoch == 2:
-        #     logger.info("Unfreezing encoders of generators...")
-        #     unfreeze_encoders(G, F)
-        #     logger.info("Parameters of generator G after unfreezing:")
-        #     logger.info(print_trainable_parameters(G))
+        # after 1 full epoch, unfreeze
+        if epoch == 2:
+            logger.info("Unfreezing encoders of generators...")
+            unfreeze_encoders(G, F)
+            logger.info("Parameters of generator G after unfreezing:")
+            logger.info(print_trainable_parameters(G))

         val_metrics = perform_epoch(
             cfg,
