@@ -59,6 +59,23 @@ def parse_args() -> argparse.Namespace:
         default=32,
         help="Batch size per device during evaluation.",
     )
+    p.add_argument(
+        "--lambda_cyc_value",
+        type=float,
+        default=7,
+        help="Weight for the cycle-consistency loss.",
+    )
+    p.add_argument(
+        "--lambda_id_value",
+        type=float,
+        default=7,
+        help="Weight for the identity loss.",
+    )
+    p.add_argument(
+        "--weight_decay",
+        type=float,
+        default=1e-4,
+    )
 
     # other params
     p.add_argument(
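A minimal standalone sketch (flag names mirror the diff; the parsed values are illustrative, not the training defaults) of why these weights and the decay are float-typed: with type=int, argparse would reject non-integer command-line values such as "1e-4" or "0.05".

```python
# Standalone sketch: float-typed flags parse scientific notation and fractional
# weights; with type=int, argparse raises "invalid int value: '1e-4'" instead.
import argparse

p = argparse.ArgumentParser()
p.add_argument("--lambda_cyc_value", type=float, default=7)
p.add_argument("--lambda_id_value", type=float, default=7)
p.add_argument("--weight_decay", type=float, default=1e-4)

cfg = p.parse_args(["--weight_decay", "1e-4", "--lambda_id_value", "0.05"])
print(cfg.weight_decay, cfg.lambda_id_value)  # 0.0001 0.05
```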
@@ -127,15 +144,39 @@ def initialize_optimizers(cfg, G, F, DX, DY):
     # Track all generator params (even encoder params that stay frozen during initial training).
     # This lets us switch to full fine-tuning later by simply toggling requires_grad=True,
     # since the optimizers already track every parameter from the start (sketched standalone below this hunk).
-    opt_G = optim.Adam(G.parameters(), lr=cfg.gen_lr, betas=(0.5, 0.999), fused=True)
-    opt_F = optim.Adam(F.parameters(), lr=cfg.gen_lr, betas=(0.5, 0.999), fused=True)
-    opt_DX = optim.Adam(DX.parameters(), lr=cfg.disc_lr, betas=(0.5, 0.999), fused=True)
-    opt_DY = optim.Adam(DY.parameters(), lr=cfg.disc_lr, betas=(0.5, 0.999), fused=True)
+    opt_G = optim.Adam(
+        G.parameters(),
+        lr=cfg.gen_lr,
+        betas=(0.5, 0.999),
+        fused=True,
+        weight_decay=cfg.weight_decay,
+    )
+    opt_F = optim.Adam(
+        F.parameters(),
+        lr=cfg.gen_lr,
+        betas=(0.5, 0.999),
+        fused=True,
+        weight_decay=cfg.weight_decay,
+    )
+    opt_DX = optim.Adam(
+        DX.parameters(),
+        lr=cfg.disc_lr,
+        betas=(0.5, 0.999),
+        fused=True,
+        weight_decay=cfg.weight_decay,
+    )
+    opt_DY = optim.Adam(
+        DY.parameters(),
+        lr=cfg.disc_lr,
+        betas=(0.5, 0.999),
+        fused=True,
+        weight_decay=cfg.weight_decay,
+    )
 
     return opt_G, opt_F, opt_DX, opt_DY
 
 
-def initialize_loss_functions(lambda_cyc_value: float = 2.0, lambda_id_value: float = 0.05):
+def initialize_loss_functions(lambda_cyc_value: float = 10.0, lambda_id_value: float = 5.0):
     mse = nn.MSELoss()
     l1 = nn.L1Loss()
     lambda_cyc = lambda_cyc_value
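A self-contained sketch of the freeze/unfreeze pattern described in the comment above (the module names are illustrative stand-ins, not the repo's generator): because the Adam optimizers are constructed over all parameters up front, moving to full fine-tuning only requires flipping requires_grad, not rebuilding optimizers or schedulers.

```python
# Illustrative sketch, not the repo's models: the optimizer tracks every
# parameter from the start; frozen params simply receive no gradients and
# are skipped by Adam until they are unfrozen again.
import torch
from torch import nn, optim

gen = nn.Sequential(nn.Linear(8, 16), nn.Linear(16, 8))  # gen[0] stands in for the encoder
opt = optim.Adam(gen.parameters(), lr=2e-4, betas=(0.5, 0.999), weight_decay=1e-4)

# Phase 1: initial training with the "encoder" frozen.
for p in gen[0].parameters():
    p.requires_grad = False

# ... train for a while ...

# Phase 2: full fine-tuning, using the same optimizer instance.
for p in gen[0].parameters():
    p.requires_grad = True
```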
@@ -226,10 +267,9 @@ def perform_train_step(
     # Loss 2: cycle terms
     loss_cyc = lambda_cyc * (l1(rec_x, x) + l1(rec_y, y))
     # Loss 3: identity terms
-    loss_id = lambda_id * 0.5 * (l1(G(y), y) + l1(F(x), x))
+    loss_id = lambda_id * (l1(G(y), y) + l1(F(x), x))
     # Total loss
-    loss_gan = 0.5 * (loss_g_adv + loss_f_adv)
-    loss_gen_total = loss_gan + loss_cyc + loss_id
+    loss_gen_total = loss_g_adv + loss_f_adv + loss_cyc + loss_id
 
     # Backprop + grad norm + step
     accelerator.backward(loss_gen_total)
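A toy comparison (synthetic scalar values, not real losses) of the old and new generator objectives in this hunk: dropping the 0.5 factors doubles the adversarial and identity contributions for a given set of weights, and the new lambda defaults (10 and 5 versus the previous 2.0 and 0.05) shift the balance further toward the cycle and identity terms.

```python
# Toy comparison with synthetic scalars (placeholders, not real loss values).
loss_g_adv, loss_f_adv = 1.0, 1.0
cyc_raw, id_raw = 0.3, 0.2  # unweighted L1 sums

# Old composition: averaged adversarial term, 0.5-scaled identity term,
# with the previous defaults lambda_cyc=2.0, lambda_id=0.05.
old_total = 0.5 * (loss_g_adv + loss_f_adv) + 2.0 * cyc_raw + 0.05 * 0.5 * id_raw

# New composition: plain sums with the new defaults lambda_cyc=10, lambda_id=5.
new_total = (loss_g_adv + loss_f_adv) + 10 * cyc_raw + 5 * id_raw

print(old_total, new_total)  # 1.605 6.0
```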
@@ -316,10 +356,9 @@ def evaluate_epoch(
     # Loss 2: cycle terms
     loss_cyc = lambda_cyc * (l1(rec_x, x) + l1(rec_y, y))
     # Loss 3: identity terms
-    loss_id = lambda_id * 0.5 * (l1(G(y), y) + l1(F(x), x))
+    loss_id = lambda_id * (l1(G(y), y) + l1(F(x), x))
     # Total loss
-    loss_gen = 0.5 * (loss_g_adv + loss_f_adv)
-    loss_gen_total = loss_gen + loss_cyc + loss_id
+    loss_gen_total = loss_g_adv + loss_f_adv + loss_cyc + loss_id
     # FID metric (normalize from [-1, 1] to [0, 1])
     # FID expects float32 images; mixed-precision batches can raise a dtype warning unless converted.
     fid_metric.update((y * 0.5 + 0.5).float(), real=True)
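A hedged sketch of the FID bookkeeping referenced in the comment above, assuming fid_metric is torchmetrics' FrechetInceptionDistance (the diff does not show its construction): with normalize=True the metric expects float images in [0, 1], which is why the [-1, 1] model outputs are shifted with y * 0.5 + 0.5 and cast to float32.

```python
# Assumption: torchmetrics' FrechetInceptionDistance (needs torch-fidelity installed).
import torch
from torchmetrics.image.fid import FrechetInceptionDistance

fid_metric = FrechetInceptionDistance(feature=64, normalize=True)  # normalize=True -> float in [0, 1]

y = torch.rand(16, 3, 64, 64) * 2 - 1       # stand-in real batch in [-1, 1]
fake_y = torch.rand(16, 3, 64, 64) * 2 - 1  # stand-in generated batch in [-1, 1]

fid_metric.update((y * 0.5 + 0.5).float(), real=True)
fid_metric.update((fake_y * 0.5 + 0.5).float(), real=False)
print(fid_metric.compute())
```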
@@ -540,7 +579,9 @@ def main() -> None:
         test_loader,
     )
     # Loss functions and loss weights
-    mse, l1, lambda_cyc, lambda_id = initialize_loss_functions()
+    mse, l1, lambda_cyc, lambda_id = initialize_loss_functions(
+        cfg.lambda_cyc_value, cfg.lambda_id_value
+    )
     # Initialize schedulers (it is important this comes AFTER wrapping the optimizers in accelerator)
     sched_G, sched_F, sched_DX, sched_DY = make_schedulers(
         cfg, opt_G, opt_F, opt_DX, opt_DY
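A minimal ordering sketch for the scheduler comment above, assuming Hugging Face Accelerate is the wrapper in use (the model, optimizer, and scheduler here are illustrative, not the repo's): schedulers should be built from the optimizer objects returned by accelerator.prepare() so that they step the wrapped instances.

```python
# Illustrative ordering sketch, assuming Hugging Face Accelerate.
import torch
from torch import nn, optim
from accelerate import Accelerator

accelerator = Accelerator()
model = nn.Linear(4, 4)
opt_G = optim.Adam(model.parameters(), lr=2e-4, betas=(0.5, 0.999))

# Wrap first ...
model, opt_G = accelerator.prepare(model, opt_G)

# ... then build the scheduler from the wrapped optimizer.
sched_G = optim.lr_scheduler.LinearLR(opt_G, start_factor=1.0, end_factor=0.01, total_iters=100)
```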