1515from aging_gan .utils import (
1616 set_seed ,
1717 load_environ_vars ,
18- print_trainable_parameters ,
1918 save_checkpoint ,
2019 generate_and_save_samples ,
2120 get_device ,
2221)
2322from aging_gan .data import prepare_dataset
24- from aging_gan .model import initialize_models , freeze_encoders , unfreeze_encoders
23+ from aging_gan .model import initialize_models
2524from aging_gan .utils import archive_and_terminate
2625
2726logger = logging .getLogger (__name__ )
@@ -59,6 +58,12 @@ def parse_args() -> argparse.Namespace:
5958 default = 32 ,
6059 help = "Batch size per device during evaluation." ,
6160 )
61+ p .add_argument (
62+ "--lambda_adv_value" ,
63+ type = int ,
64+ default = 2 ,
65+ help = "Weight for adversarial loss" ,
66+ )
6267 p .add_argument (
6368 "--lambda_cyc_value" ,
6469 type = int ,
@@ -98,24 +103,6 @@ def parse_args() -> argparse.Namespace:
98103 default = 10 ,
99104 help = "The number of example generated images to save per epoch." ,
100105 )
101- p .add_argument (
102- "--train_size" ,
103- type = int ,
104- default = 3000 ,
105- help = "The size of train dataset to train on." ,
106- )
107- p .add_argument (
108- "--val_size" ,
109- type = int ,
110- default = 800 ,
111- help = "The size of validation dataset to evaluate." ,
112- )
113- p .add_argument (
114- "--test_size" ,
115- type = int ,
116- default = 800 ,
117- help = "The size of test dataset to evaluate." ,
118- )
119106 p .add_argument (
120107 "--num_workers" ,
121108 type = int ,
@@ -176,13 +163,16 @@ def initialize_optimizers(cfg, G, F, DX, DY):
176163 return opt_G , opt_F , opt_DX , opt_DY
177164
178165
def initialize_loss_functions(
    lambda_adv_value: int = 2, lambda_cyc_value: int = 10, lambda_id_value: int = 5
):
    """Create the loss criteria and weighting factors used for GAN training.

    Args:
        lambda_adv_value: Multiplier for the adversarial (least-squares GAN)
            loss terms.
        lambda_cyc_value: Multiplier for the cycle-consistency loss terms.
        lambda_id_value: Multiplier for the identity-mapping loss terms.

    Returns:
        A 5-tuple ``(mse, l1, lambda_adv, lambda_cyc, lambda_id)``: the MSE
        criterion (adversarial terms), the L1 criterion (cycle and identity
        terms), and the three loss weights, in that order.
    """
    # Criteria: MSE for the LSGAN-style adversarial losses, L1 for the
    # reconstruction-style (cycle / identity) losses.
    criterion_mse = nn.MSELoss()
    criterion_l1 = nn.L1Loss()
    # The weights pass straight through; returning them alongside the criteria
    # gives callers everything needed to assemble the full objective.
    return (
        criterion_mse,
        criterion_l1,
        lambda_adv_value,
        lambda_cyc_value,
        lambda_id_value,
    )
186176
187177
188178def make_schedulers (cfg , opt_G , opt_F , opt_DX , opt_DY ):
@@ -212,6 +202,7 @@ def perform_train_step(
212202 real_data ,
213203 mse ,
214204 l1 ,
205+ lambda_adv ,
215206 lambda_cyc ,
216207 lambda_id , # loss functions and loss params
217208 opt_G ,
@@ -260,10 +251,10 @@ def perform_train_step(
260251 opt_F .zero_grad (set_to_none = True )
261252 # Loss 1: adversarial terms
262253 fake_test_logits = DX (fake_x ) # fake x logits
263- loss_f_adv = mse (fake_test_logits , torch .ones_like (fake_test_logits ))
254+ loss_f_adv = lambda_adv * mse (fake_test_logits , torch .ones_like (fake_test_logits ))
264255
265256 fake_test_logits = DY (fake_y ) # fake y logits
266- loss_g_adv = mse (fake_test_logits , torch .ones_like (fake_test_logits ))
257+ loss_g_adv = lambda_adv * mse (fake_test_logits , torch .ones_like (fake_test_logits ))
267258 # Loss 2: cycle terms
268259 loss_cyc = lambda_cyc * (l1 (rec_x , x ) + l1 (rec_y , y ))
269260 # Loss 3: identity terms
@@ -299,6 +290,7 @@ def evaluate_epoch(
299290 split : str , # either "val" or "test"
300291 mse ,
301292 l1 ,
293+ lambda_adv ,
302294 lambda_cyc ,
303295 lambda_id , # loss functions and loss params
304296 fid_metric ,
@@ -349,10 +341,14 @@ def evaluate_epoch(
349341 # ------ Evaluate Generators ------
350342 # Loss 1: adversarial terms
351343 fake_test_logits = DX (fake_x ) # fake x logits
352- loss_f_adv = mse (fake_test_logits , torch .ones_like (fake_test_logits ))
344+ loss_f_adv = lambda_adv * mse (
345+ fake_test_logits , torch .ones_like (fake_test_logits )
346+ )
353347
354348 fake_test_logits = DY (fake_y ) # fake y logits
355- loss_g_adv = mse (fake_test_logits , torch .ones_like (fake_test_logits ))
349+ loss_g_adv = lambda_adv * mse (
350+ fake_test_logits , torch .ones_like (fake_test_logits )
351+ )
356352 # Loss 2: cycle terms
357353 loss_cyc = lambda_cyc * (l1 (rec_x , x ) + l1 (rec_y , y ))
358354 # Loss 3: identity terms
@@ -396,6 +392,7 @@ def perform_epoch(
396392 DY ,
397393 mse ,
398394 l1 ,
395+ lambda_adv ,
399396 lambda_cyc ,
400397 lambda_id ,
401398 opt_G ,
@@ -427,6 +424,7 @@ def perform_epoch(
427424 real_data ,
428425 mse ,
429426 l1 ,
427+ lambda_adv ,
430428 lambda_cyc ,
431429 lambda_id , # loss functions and loss params
432430 opt_G ,
@@ -469,6 +467,7 @@ def perform_epoch(
469467 "val" ,
470468 mse ,
471469 l1 ,
470+ lambda_adv ,
472471 lambda_cyc ,
473472 lambda_id , # loss functions and loss params
474473 fid_metric , # evaluation metric
@@ -527,22 +526,19 @@ def main() -> None:
527526 cfg .train_batch_size ,
528527 cfg .eval_batch_size ,
529528 cfg .num_workers ,
530- train_size = cfg .train_size ,
531- val_size = cfg .val_size ,
532- test_size = cfg .test_size ,
533529 seed = cfg .seed ,
534530 )
535531
536532 # ---------- Models, Optimizers, Loss Functions, Schedulers Initialization ----------
537533 # Initialize the generators (G, F) and discriminators (DX, DY)
538534 G , F , DX , DY = initialize_models ()
539535 # Freeze generator encoders for training during early epochs
540- logger .info ("Parameters of generator G:" )
541- logger .info (print_trainable_parameters (G ))
542- logger .info ("Freezing encoders of generators..." )
543- freeze_encoders (G , F )
544- logger .info ("Parameters of generator G after freezing:" )
545- logger .info (print_trainable_parameters (G ))
536+ # logger.info("Parameters of generator G:")
537+ # logger.info(print_trainable_parameters(G))
538+ # logger.info("Freezing encoders of generators...")
539+ # freeze_encoders(G, F)
540+ # logger.info("Parameters of generator G after freezing:")
541+ # logger.info(print_trainable_parameters(G))
546542 # Initialize optimizers
547543 (
548544 opt_G ,
@@ -579,8 +575,8 @@ def main() -> None:
579575 test_loader ,
580576 )
581577 # Loss functions and scalers
582- mse , l1 , lambda_cyc , lambda_id = initialize_loss_functions (
583- cfg .lambda_cyc_value , cfg .lambda_id_value
578+ mse , l1 , lambda_adv , lambda_cyc , lambda_id = initialize_loss_functions (
579+ cfg .lambda_adv_value , cfg . lambda_cyc_value , cfg .lambda_id_value
584580 )
585581 # Initialize schedulers (It is important this comes AFTER wrapping optimizers in accelerator)
586582 sched_G , sched_F , sched_DX , sched_DY = make_schedulers (
@@ -596,11 +592,11 @@ def main() -> None:
596592 for epoch in range (1 , cfg .num_train_epochs + 1 ):
597593 logger .info (f"\n EPOCH { epoch } " )
598594 # after 1 full epoch, unfreeze
599- if epoch == 2 :
600- logger .info ("Unfreezing encoders of generators..." )
601- unfreeze_encoders (G , F )
602- logger .info ("Parameters of generator G after unfreezing:" )
603- logger .info (print_trainable_parameters (G ))
595+ # if epoch == 2:
596+ # logger.info("Unfreezing encoders of generators...")
597+ # unfreeze_encoders(G, F)
598+ # logger.info("Parameters of generator G after unfreezing:")
599+ # logger.info(print_trainable_parameters(G))
604600
605601 val_metrics = perform_epoch (
606602 cfg ,
@@ -612,6 +608,7 @@ def main() -> None:
612608 DY ,
613609 mse ,
614610 l1 ,
611+ lambda_adv ,
615612 lambda_cyc ,
616613 lambda_id ,
617614 opt_G ,
0 commit comments