@@ -78,7 +78,7 @@ def parse_args() -> argparse.Namespace:
     )
     p.add_argument(
         "--weight_decay",
-        type=int,
+        type=float,
         default=1e-4,
     )
 
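Note on the type fix: argparse only applies type to values parsed from the command line (an already-numeric default passes through untouched), so type=int would have made any CLI override such as --weight_decay 1e-4 fail with "invalid int value". A minimal repro of the corrected behavior:

import argparse

p = argparse.ArgumentParser()
p.add_argument("--weight_decay", type=float, default=1e-4)
args = p.parse_args(["--weight_decay", "1e-4"])
print(args.weight_decay)  # 0.0001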
@@ -135,28 +135,24 @@ def initialize_optimizers(cfg, G, F, DX, DY):
         G.parameters(),
         lr=cfg.gen_lr,
         betas=(0.5, 0.999),
-        fused=True,
         weight_decay=cfg.weight_decay,
     )
     opt_F = optim.Adam(
         F.parameters(),
         lr=cfg.gen_lr,
         betas=(0.5, 0.999),
-        fused=True,
         weight_decay=cfg.weight_decay,
     )
     opt_DX = optim.Adam(
         DX.parameters(),
         lr=cfg.disc_lr,
         betas=(0.5, 0.999),
-        fused=True,
         weight_decay=cfg.weight_decay,
     )
     opt_DY = optim.Adam(
         DY.parameters(),
         lr=cfg.disc_lr,
         betas=(0.5, 0.999),
-        fused=True,
         weight_decay=cfg.weight_decay,
     )
 
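The diff does not state why fused=True is dropped; fused Adam only runs on supported device/dtype combinations (it was CUDA-only when introduced), so removing the flag is the portable choice. If the fused path were still wanted opportunistically, a guarded construction could look like the sketch below (this is not what the patch does; use_fused is a hypothetical name):

use_fused = torch.cuda.is_available()  # fused Adam requires parameters on a supported device
opt_G = optim.Adam(
    G.parameters(),
    lr=cfg.gen_lr,
    betas=(0.5, 0.999),
    weight_decay=cfg.weight_decay,
    fused=use_fused,
)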
@@ -212,64 +208,68 @@ def perform_train_step(
     accelerator,
 ):
     x, y = real_data
-    # Generate fakes and reconstructions
-    fake_x = F(y)
-    fake_y = G(x)
-    rec_x = F(fake_y)
-    rec_y = G(fake_x)
+    # ------ Update Generators ------
+    opt_G.zero_grad(set_to_none=True)
+    opt_F.zero_grad(set_to_none=True)
+    with accelerator.autocast():
+        fake_x = F(y)
+        fake_y = G(x)
+        rec_x = F(fake_y)
+        rec_y = G(fake_x)
+        # Loss 1: adversarial terms
+        fake_test_logits = DX(fake_x)  # fake x logits
+        loss_f_adv = lambda_adv * mse(
+            fake_test_logits, torch.ones_like(fake_test_logits)
+        )
+        fake_test_logits = DY(fake_y)  # fake y logits
+        loss_g_adv = lambda_adv * mse(
+            fake_test_logits, torch.ones_like(fake_test_logits)
+        )
+        # Loss 2: cycle terms
+        loss_cyc = lambda_cyc * (l1(rec_x, x) + l1(rec_y, y))
+        # Loss 3: identity terms
+        loss_id = lambda_id * (l1(G(y), y) + l1(F(x), x))
+        # Total loss
+        loss_gen_total = loss_g_adv + loss_f_adv + loss_cyc + loss_id
+    # Backprop + grad norm + step
+    accelerator.backward(loss_gen_total)
+    accelerator.clip_grad_norm_(
+        list(G.parameters()) + list(F.parameters()), max_norm=1.0
+    )
+    opt_G.step()
+    opt_F.step()
 
     # ------ Update Discriminators ------
     # DX: real young vs fake young
     opt_DX.zero_grad(set_to_none=True)
-    real_logits = DX(x)
-    real_loss = mse(real_logits, torch.ones_like(real_logits))
-    fake_logits = DX(fake_x.detach())
-    fake_loss = mse(fake_logits, torch.zeros_like(fake_logits))
-    # DX loss + backprop + grad norm + step
-    loss_DX = 0.5 * (real_loss + fake_loss)
+    with accelerator.autocast():
+        real_logits = DX(x)
+        real_loss = mse(real_logits, torch.ones_like(real_logits))
+        fake_logits = DX(fake_x.detach())
+        fake_loss = mse(fake_logits, torch.zeros_like(fake_logits))
+        # DX loss
+        loss_DX = 0.5 * (real_loss + fake_loss)
+    # backprop + grad norm + step
     accelerator.backward(loss_DX)
     accelerator.clip_grad_norm_(DX.parameters(), max_norm=1.0)
     opt_DX.step()
 
     # DY: real old vs fake old
     opt_DY.zero_grad(set_to_none=True)
-    real_logits = DY(y)
-    real_loss = mse(real_logits, torch.ones_like(real_logits))
-    fake_logits = DY(fake_y.detach())
-    fake_loss = mse(fake_logits, torch.zeros_like(fake_logits))
-
-    # DY loss + backprop + grad norm + step
-    loss_DY = 0.5 * (
-        real_loss + fake_loss
-    )  # average loss to prevent the discriminator from learning "too quickly" compared to the generators.
+    with accelerator.autocast():
+        real_logits = DY(y)
+        real_loss = mse(real_logits, torch.ones_like(real_logits))
+        fake_logits = DY(fake_y.detach())
+        fake_loss = mse(fake_logits, torch.zeros_like(fake_logits))
+        # DY loss
+        loss_DY = 0.5 * (
+            real_loss + fake_loss
+        )  # average loss to prevent the discriminator from learning "too quickly" compared to the generators.
+    # backprop + grad norm + step
     accelerator.backward(loss_DY)
     accelerator.clip_grad_norm_(DY.parameters(), max_norm=1.0)
     opt_DY.step()
 
-    # ------ Update Generators ------
-    opt_G.zero_grad(set_to_none=True)
-    opt_F.zero_grad(set_to_none=True)
-    # Loss 1: adversarial terms
-    fake_test_logits = DX(fake_x)  # fake x logits
-    loss_f_adv = lambda_adv * mse(fake_test_logits, torch.ones_like(fake_test_logits))
-
-    fake_test_logits = DY(fake_y)  # fake y logits
-    loss_g_adv = lambda_adv * mse(fake_test_logits, torch.ones_like(fake_test_logits))
-    # Loss 2: cycle terms
-    loss_cyc = lambda_cyc * (l1(rec_x, x) + l1(rec_y, y))
-    # Loss 3: identity terms
-    loss_id = lambda_id * (l1(G(y), y) + l1(F(x), x))
-    # Total loss
-    loss_gen_total = loss_g_adv + loss_f_adv + loss_cyc + loss_id
-
-    # Backprop + grad norm + step
-    accelerator.backward(loss_gen_total)
-    accelerator.clip_grad_norm_(
-        list(G.parameters()) + list(F.parameters()), max_norm=1.0
-    )
-    opt_G.step()
-    opt_F.step()
-
     return {
         "train/loss_DX": loss_DX.item(),
         "train/loss_DY": loss_DY.item(),
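Every update in the reordered train step now follows the same Accelerate mixed-precision pattern; a minimal sketch with placeholder names (model, opt, compute_loss, and batch are hypothetical):

opt.zero_grad(set_to_none=True)
with accelerator.autocast():               # forward pass in reduced precision
    loss = compute_loss(model, batch)
accelerator.backward(loss)                 # applies loss scaling when mixed precision is enabled
accelerator.clip_grad_norm_(model.parameters(), max_norm=1.0)  # unscales gradients before clipping
opt.step()                                 # optimizer math stays in full precision

Keeping backward() and the optimizer step outside the autocast block matches PyTorch's guidance that only the forward pass should run under autocast.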
@@ -294,6 +294,7 @@ def evaluate_epoch(
     lambda_cyc,
     lambda_id,  # loss functions and loss params
     fid_metric,
+    accelerator,
 ):
     metrics = {
         f"{split}/loss_DX": 0.0,
@@ -307,14 +308,31 @@ def evaluate_epoch(
     }
     n_batches = 0
 
-    with torch.no_grad():
+    with torch.no_grad(), accelerator.autocast():
         fid_metric.reset()
         for x, y in tqdm(loader):
             # Forward: generate fakes and reconstructions
             fake_x = F(y)
             fake_y = G(x)
             rec_x = F(fake_y)
             rec_y = G(fake_x)
+            # ------ Evaluate Generators ------
+            # Loss 1: adversarial terms
+            fake_test_logits = DX(fake_x)  # fake x logits
+            loss_f_adv = lambda_adv * mse(
+                fake_test_logits, torch.ones_like(fake_test_logits)
+            )
+
+            fake_test_logits = DY(fake_y)  # fake y logits
+            loss_g_adv = lambda_adv * mse(
+                fake_test_logits, torch.ones_like(fake_test_logits)
+            )
+            # Loss 2: cycle terms
+            loss_cyc = lambda_cyc * (l1(rec_x, x) + l1(rec_y, y))
+            # Loss 3: identity terms
+            loss_id = lambda_id * (l1(G(y), y) + l1(F(x), x))
+            # Total loss
+            loss_gen_total = loss_g_adv + loss_f_adv + loss_cyc + loss_id
 
             # ------ Evaluate Discriminators ------
             # DX: real young vs fake young
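With this change the generator-loss arithmetic is duplicated verbatim between perform_train_step and evaluate_epoch; a follow-up could factor it into a shared helper. A hypothetical sketch (generator_losses is not part of this patch; the names mirror the code above):

import torch

def generator_losses(G, F, DX, DY, x, y, mse, l1, lambda_adv, lambda_cyc, lambda_id):
    fake_x, fake_y = F(y), G(x)
    rec_x, rec_y = F(fake_y), G(fake_x)
    logits = DX(fake_x)
    loss_f_adv = lambda_adv * mse(logits, torch.ones_like(logits))
    logits = DY(fake_y)
    loss_g_adv = lambda_adv * mse(logits, torch.ones_like(logits))
    loss_cyc = lambda_cyc * (l1(rec_x, x) + l1(rec_y, y))
    loss_id = lambda_id * (l1(G(y), y) + l1(F(x), x))
    return fake_x, fake_y, loss_g_adv + loss_f_adv + loss_cyc + loss_id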
@@ -338,23 +356,6 @@ def evaluate_epoch(
                 real_loss + fake_loss
             )  # average loss to prevent the discriminator from learning "too quickly" compared to the generators.
 
-            # ------ Evaluate Generators ------
-            # Loss 1: adversarial terms
-            fake_test_logits = DX(fake_x)  # fake x logits
-            loss_f_adv = lambda_adv * mse(
-                fake_test_logits, torch.ones_like(fake_test_logits)
-            )
-
-            fake_test_logits = DY(fake_y)  # fake y logits
-            loss_g_adv = lambda_adv * mse(
-                fake_test_logits, torch.ones_like(fake_test_logits)
-            )
-            # Loss 2: cycle terms
-            loss_cyc = lambda_cyc * (l1(rec_x, x) + l1(rec_y, y))
-            # Loss 3: identity terms
-            loss_id = lambda_id * (l1(G(y), y) + l1(F(x), x))
-            # Total loss
-            loss_gen_total = loss_g_adv + loss_f_adv + loss_cyc + loss_id
             # FID metric (normalize to range of [0,1] from [-1,1])
             # FID expects float32 images, which can raise dtype warnings for mixed-precision batches unless converted.
             fid_metric.update((y * 0.5 + 0.5).float(), real=True)
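The construction of fid_metric is outside this diff; assuming it is torchmetrics' FrechetInceptionDistance, the [0,1] float convention in the comment corresponds to building the metric with normalize=True, roughly:

from torchmetrics.image.fid import FrechetInceptionDistance

fid_metric = FrechetInceptionDistance(feature=2048, normalize=True)  # expects floats in [0, 1]
fid_metric.update((y * 0.5 + 0.5).float(), real=True)        # real batch, rescaled from [-1, 1]
fid_metric.update((fake_y * 0.5 + 0.5).float(), real=False)  # generated batch
score = fid_metric.compute()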
@@ -471,6 +472,7 @@ def perform_epoch(
        lambda_cyc,
        lambda_id,  # loss functions and loss params
        fid_metric,  # evaluation metric
+        accelerator,
    )
    logger.info(
        f"val/loss_DX: {val_metrics['val/loss_DX']:.4f} | val/loss_DY: {val_metrics['val/loss_DY']:.4f} | val/fid_val: {val_metrics['val/fid_val']:.4f} | val/loss_gen_total: {val_metrics['val/loss_gen_total']:.4f} | val/loss_g_adv: {val_metrics['val/loss_g_adv']:.4f} | val/loss_f_adv: {val_metrics['val/loss_f_adv']:.4f} | val/loss_cyc: {val_metrics['val/loss_cyc']:.4f} | val/loss_id: {val_metrics['val/loss_id']:.4f}"
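Stylistic aside, not part of the patch: the long hand-written f-string above could equivalently be built from the metrics dict, e.g.

logger.info(" | ".join(f"{k}: {v:.4f}" for k, v in val_metrics.items()))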