
Commit 1eb9d45

fixed train.py to reflect model changes
1 parent dd8f9e9 commit 1eb9d45

1 file changed (+53, -12)


src/aging_gan/train.py

Lines changed: 53 additions & 12 deletions
@@ -59,6 +59,23 @@ def parse_args() -> argparse.Namespace:
         default=32,
         help="Batch size per device during evaluation.",
     )
+    p.add_argument(
+        "--lambda_cyc_value",
+        type=int,
+        default=7,
+        help="Weight for cyclical loss",
+    )
+    p.add_argument(
+        "--lambda_id_value",
+        type=int,
+        default=7,
+        help="Weight for identity loss",
+    )
+    p.add_argument(
+        "--weight_decay",
+        type=float,
+        default=1e-4,
+    )

     # other params
     p.add_argument(
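
For readers skimming the diff, here is a standalone sketch (not code from this repository; the parser below is hypothetical) of how the three new flags parse. argparse applies the type callable to command-line strings but not to non-string defaults, so a float type is what lets users pass values such as 5e-5 for the weight decay:

import argparse

# Hypothetical, minimal parser mirroring the three flags added above.
p = argparse.ArgumentParser()
p.add_argument("--lambda_cyc_value", type=int, default=7, help="Weight for cyclical loss")
p.add_argument("--lambda_id_value", type=int, default=7, help="Weight for identity loss")
p.add_argument("--weight_decay", type=float, default=1e-4)

# Overriding the defaults from the command line.
cfg = p.parse_args(["--lambda_cyc_value", "10", "--lambda_id_value", "5", "--weight_decay", "5e-5"])
print(cfg.lambda_cyc_value, cfg.lambda_id_value, cfg.weight_decay)  # 10 5 5e-05
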
@@ -127,15 +144,39 @@ def initialize_optimizers(cfg, G, F, DX, DY):
     # track all generator params (even frozen encoder params during initial training).
     # This would allow us to transition easily to the full fine-tuning later on by simply toggling requires_grad=True
     # since the optimizers already track all the parameters from the start.
-    opt_G = optim.Adam(G.parameters(), lr=cfg.gen_lr, betas=(0.5, 0.999), fused=True)
-    opt_F = optim.Adam(F.parameters(), lr=cfg.gen_lr, betas=(0.5, 0.999), fused=True)
-    opt_DX = optim.Adam(DX.parameters(), lr=cfg.disc_lr, betas=(0.5, 0.999), fused=True)
-    opt_DY = optim.Adam(DY.parameters(), lr=cfg.disc_lr, betas=(0.5, 0.999), fused=True)
+    opt_G = optim.Adam(
+        G.parameters(),
+        lr=cfg.gen_lr,
+        betas=(0.5, 0.999),
+        fused=True,
+        weight_decay=cfg.weight_decay,
+    )
+    opt_F = optim.Adam(
+        F.parameters(),
+        lr=cfg.gen_lr,
+        betas=(0.5, 0.999),
+        fused=True,
+        weight_decay=cfg.weight_decay,
+    )
+    opt_DX = optim.Adam(
+        DX.parameters(),
+        lr=cfg.disc_lr,
+        betas=(0.5, 0.999),
+        fused=True,
+        weight_decay=cfg.weight_decay,
+    )
+    opt_DY = optim.Adam(
+        DY.parameters(),
+        lr=cfg.disc_lr,
+        betas=(0.5, 0.999),
+        fused=True,
+        weight_decay=cfg.weight_decay,
+    )

     return opt_G, opt_F, opt_DX, opt_DY


-def initialize_loss_functions(lambda_cyc_value: int = 2.0, lambda_id_value: int = 0.05):
+def initialize_loss_functions(lambda_cyc_value: int = 10, lambda_id_value: int = 5):
     mse = nn.MSELoss()
     l1 = nn.L1Loss()
     lambda_cyc = lambda_cyc_value
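
The optimizers above now take weight_decay from the config while keeping fused=True. A self-contained sketch of the same construction (toy nn.Linear module and placeholder hyperparameters, not the project's generators or discriminators); fused Adam generally expects parameters on a CUDA device, so the sketch guards on availability:

import torch
from torch import nn, optim

net = nn.Linear(16, 16)  # toy stand-in for a generator or discriminator

use_fused = torch.cuda.is_available()  # fused Adam is CUDA-side on most PyTorch versions
if use_fused:
    net = net.cuda()

opt = optim.Adam(
    net.parameters(),
    lr=2e-4,               # placeholder for cfg.gen_lr / cfg.disc_lr
    betas=(0.5, 0.999),    # same GAN-style betas as in the diff
    weight_decay=1e-4,     # placeholder for cfg.weight_decay
    fused=use_fused,
)

One design note: Adam's weight_decay is classic L2 regularization folded into the gradient, whereas optim.AdamW decouples the decay from the adaptive step; the commit keeps plain Adam, so the decay interacts with the per-parameter learning rates.
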
@@ -226,10 +267,9 @@ def perform_train_step(
     # Loss 2: cycle terms
     loss_cyc = lambda_cyc * (l1(rec_x, x) + l1(rec_y, y))
     # Loss 3: identity terms
-    loss_id = lambda_id * 0.5 * (l1(G(y), y) + l1(F(x), x))
+    loss_id = lambda_id * (l1(G(y), y) + l1(F(x), x))
     # Total loss
-    loss_gan = 0.5 * (loss_g_adv + loss_f_adv)
-    loss_gen_total = loss_gan + loss_cyc + loss_id
+    loss_gen_total = loss_g_adv + loss_f_adv + loss_cyc + loss_id

     # Backprop + grad norm + step
     accelerator.backward(loss_gen_total)
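
With this hunk the generator objective is a plain sum of the adversarial, cycle, and identity terms; the earlier 0.5 averaging of the adversarial pair and the extra 0.5 on the identity term are gone, which, relative to the cycle term, doubles their effective weight and is presumably accounted for by the new lambda defaults. Assuming the usual CycleGAN reconstructions rec_x = F(G(x)) and rec_y = G(F(y)) (their definitions lie outside this hunk), the total reads

L_total = L_adv(G) + L_adv(F) + \lambda_{cyc} (\lVert F(G(x)) - x \rVert_1 + \lVert G(F(y)) - y \rVert_1) + \lambda_{id} (\lVert G(y) - y \rVert_1 + \lVert F(x) - x \rVert_1)

with \lambda_{cyc} and \lambda_{id} supplied by the new --lambda_cyc_value and --lambda_id_value flags.
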
@@ -316,10 +356,9 @@ def evaluate_epoch(
     # Loss 2: cycle terms
     loss_cyc = lambda_cyc * (l1(rec_x, x) + l1(rec_y, y))
     # Loss 3: identity terms
-    loss_id = lambda_id * 0.5 * (l1(G(y), y) + l1(F(x), x))
+    loss_id = lambda_id * (l1(G(y), y) + l1(F(x), x))
     # Total loss
-    loss_gen = 0.5 * (loss_g_adv + loss_f_adv)
-    loss_gen_total = loss_gen + loss_cyc + loss_id
+    loss_gen_total = loss_g_adv + loss_f_adv + loss_cyc + loss_id
     # FID metric (normalize to range of [0,1] from [-1,1])
     # FID expects float32 images, which can raise dtype warning for mixed precision batches unless converted.
     fid_metric.update((y * 0.5 + 0.5).float(), real=True)
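
The (y * 0.5 + 0.5) rescaling maps tanh outputs from [-1, 1] into [0, 1] before updating the FID metric. A standalone sketch of that pattern, assuming fid_metric is a torchmetrics FrechetInceptionDistance built with normalize=True (its construction is not shown in this commit, and torchmetrics' FID additionally needs the torch-fidelity package installed):

import torch
from torchmetrics.image.fid import FrechetInceptionDistance

# Assumption: normalize=True means the metric expects float images in [0, 1].
fid_metric = FrechetInceptionDistance(feature=64, normalize=True)

real = torch.rand(8, 3, 64, 64)               # stand-in "real" batch, already in [0, 1]
fake = torch.tanh(torch.randn(8, 3, 64, 64))  # generator-style output in [-1, 1]

fid_metric.update(real.float(), real=True)
fid_metric.update((fake * 0.5 + 0.5).float(), real=False)  # same rescaling as in the diff
print(fid_metric.compute())
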
@@ -540,7 +579,9 @@ def main() -> None:
         test_loader,
     )
     # Loss functions and scalers
-    mse, l1, lambda_cyc, lambda_id = initialize_loss_functions()
+    mse, l1, lambda_cyc, lambda_id = initialize_loss_functions(
+        cfg.lambda_cyc_value, cfg.lambda_id_value
+    )
     # Initialize schedulers (It is important this comes AFTER wrapping optimizers in accelerator)
     sched_G, sched_F, sched_DX, sched_DY = make_schedulers(
         cfg, opt_G, opt_F, opt_DX, opt_DY
