
Commit fcd7d18

compute gen first and then disc for efficiency
1 parent b9741fe commit fcd7d18

File tree: 1 file changed, +69 −67 lines


src/aging_gan/train.py

Lines changed: 69 additions & 67 deletions
@@ -78,7 +78,7 @@ def parse_args() -> argparse.Namespace:
     )
     p.add_argument(
         "--weight_decay",
-        type=int,
+        type=float,
         default=1e-4,
     )

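The type fix matters beyond style: argparse only runs the type converter on command-line strings, so the old type=int never touched the non-string default of 1e-4, but an explicit --weight_decay 1e-4 on the CLI would have exited with "invalid int value: '1e-4'". A minimal sketch of the corrected behavior (standalone, not the repo's parser):

    import argparse

    p = argparse.ArgumentParser()
    p.add_argument("--weight_decay", type=float, default=1e-4)

    # Both spellings now parse; with the old type=int, either would exit
    # with "argument --weight_decay: invalid int value: ...".
    print(p.parse_args(["--weight_decay", "1e-4"]).weight_decay)    # 0.0001
    print(p.parse_args(["--weight_decay", "0.0001"]).weight_decay)  # 0.0001
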
@@ -135,28 +135,24 @@ def initialize_optimizers(cfg, G, F, DX, DY):
         G.parameters(),
         lr=cfg.gen_lr,
         betas=(0.5, 0.999),
-        fused=True,
         weight_decay=cfg.weight_decay,
     )
     opt_F = optim.Adam(
         F.parameters(),
         lr=cfg.gen_lr,
         betas=(0.5, 0.999),
-        fused=True,
         weight_decay=cfg.weight_decay,
     )
     opt_DX = optim.Adam(
         DX.parameters(),
         lr=cfg.disc_lr,
         betas=(0.5, 0.999),
-        fused=True,
         weight_decay=cfg.weight_decay,
     )
     opt_DY = optim.Adam(
         DY.parameters(),
         lr=cfg.disc_lr,
         betas=(0.5, 0.999),
-        fused=True,
         weight_decay=cfg.weight_decay,
     )

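A plausible reason for dropping fused=True, though the commit message does not state it: PyTorch's fused Adam kernel requires the parameters to live on CUDA, so the flag breaks CPU (and MPS) runs. Where the fused kernel is still wanted on GPU, it can be gated on device instead; a sketch with a hypothetical make_adam helper, not code from this repo:

    import torch
    from torch import optim

    # Hypothetical helper: request the fused kernel only when every
    # parameter is on CUDA, where torch's fused Adam is supported.
    def make_adam(params, lr, weight_decay):
        params = list(params)
        fused_ok = len(params) > 0 and all(p.is_cuda for p in params)
        return optim.Adam(
            params,
            lr=lr,
            betas=(0.5, 0.999),
            weight_decay=weight_decay,
            fused=fused_ok,
        )
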
@@ -212,64 +208,68 @@ def perform_train_step(
     accelerator,
 ):
     x, y = real_data
-    # Generate fakes and reconstructions
-    fake_x = F(y)
-    fake_y = G(x)
-    rec_x = F(fake_y)
-    rec_y = G(fake_x)
+    # ------ Update Generators ------
+    opt_G.zero_grad(set_to_none=True)
+    opt_F.zero_grad(set_to_none=True)
+    with accelerator.autocast():
+        fake_x = F(y)
+        fake_y = G(x)
+        rec_x = F(fake_y)
+        rec_y = G(fake_x)
+        # Loss 1: adversarial terms
+        fake_test_logits = DX(fake_x)  # fake x logits
+        loss_f_adv = lambda_adv * mse(
+            fake_test_logits, torch.ones_like(fake_test_logits)
+        )
+        fake_test_logits = DY(fake_y)  # fake y logits
+        loss_g_adv = lambda_adv * mse(
+            fake_test_logits, torch.ones_like(fake_test_logits)
+        )
+        # Loss 2: cycle terms
+        loss_cyc = lambda_cyc * (l1(rec_x, x) + l1(rec_y, y))
+        # Loss 3: identity terms
+        loss_id = lambda_id * (l1(G(y), y) + l1(F(x), x))
+        # Total loss
+        loss_gen_total = loss_g_adv + loss_f_adv + loss_cyc + loss_id
+    # Backprop + grad norm + step
+    accelerator.backward(loss_gen_total)
+    accelerator.clip_grad_norm_(
+        list(G.parameters()) + list(F.parameters()), max_norm=1.0
+    )
+    opt_G.step()
+    opt_F.step()

     # ------ Update Discriminators ------
     # DX: real young vs fake young
     opt_DX.zero_grad(set_to_none=True)
-    real_logits = DX(x)
-    real_loss = mse(real_logits, torch.ones_like(real_logits))
-    fake_logits = DX(fake_x.detach())
-    fake_loss = mse(fake_logits, torch.zeros_like(fake_logits))
-    # DX loss + backprop + grad norm + step
-    loss_DX = 0.5 * (real_loss + fake_loss)
+    with accelerator.autocast():
+        real_logits = DX(x)
+        real_loss = mse(real_logits, torch.ones_like(real_logits))
+        fake_logits = DX(fake_x.detach())
+        fake_loss = mse(fake_logits, torch.zeros_like(fake_logits))
+        # DX loss
+        loss_DX = 0.5 * (real_loss + fake_loss)
+    # backprop + grad norm + step
     accelerator.backward(loss_DX)
     accelerator.clip_grad_norm_(DX.parameters(), max_norm=1.0)
     opt_DX.step()

     # DY: real old vs fake old
     opt_DY.zero_grad(set_to_none=True)
-    real_logits = DY(y)
-    real_loss = mse(real_logits, torch.ones_like(real_logits))
-    fake_logits = DY(fake_y.detach())
-    fake_loss = mse(fake_logits, torch.zeros_like(fake_logits))
-
-    # DY loss + backprop + grad norm + step
-    loss_DY = 0.5 * (
-        real_loss + fake_loss
-    )  # average loss to prevent the discriminator from learning "too quickly" compared to the generators.
+    with accelerator.autocast():
+        real_logits = DY(y)
+        real_loss = mse(real_logits, torch.ones_like(real_logits))
+        fake_logits = DY(fake_y.detach())
+        fake_loss = mse(fake_logits, torch.zeros_like(fake_logits))
+        # DY loss
+        loss_DY = 0.5 * (
+            real_loss + fake_loss
+        )  # average loss to prevent the discriminator from learning "too quickly" compared to the generators.
+    # backprop + grad norm + step
     accelerator.backward(loss_DY)
     accelerator.clip_grad_norm_(DY.parameters(), max_norm=1.0)
     opt_DY.step()

-    # ------ Update Generators ------
-    opt_G.zero_grad(set_to_none=True)
-    opt_F.zero_grad(set_to_none=True)
-    # Loss 1: adversarial terms
-    fake_test_logits = DX(fake_x)  # fake x logits
-    loss_f_adv = lambda_adv * mse(fake_test_logits, torch.ones_like(fake_test_logits))
-
-    fake_test_logits = DY(fake_y)  # fake y logits
-    loss_g_adv = lambda_adv * mse(fake_test_logits, torch.ones_like(fake_test_logits))
-    # Loss 2: cycle terms
-    loss_cyc = lambda_cyc * (l1(rec_x, x) + l1(rec_y, y))
-    # Loss 3: identity terms
-    loss_id = lambda_id * (l1(G(y), y) + l1(F(x), x))
-    # Total loss
-    loss_gen_total = loss_g_adv + loss_f_adv + loss_cyc + loss_id
-
-    # Backprop + grad norm + step
-    accelerator.backward(loss_gen_total)
-    accelerator.clip_grad_norm_(
-        list(G.parameters()) + list(F.parameters()), max_norm=1.0
-    )
-    opt_G.step()
-    opt_F.step()
-
     return {
         "train/loss_DX": loss_DX.item(),
         "train/loss_DY": loss_DY.item(),
@@ -294,6 +294,7 @@ def evaluate_epoch(
     lambda_cyc,
     lambda_id,  # loss functions and loss params
     fid_metric,
+    accelerator,
 ):
     metrics = {
         f"{split}/loss_DX": 0.0,
@@ -307,14 +308,31 @@ def evaluate_epoch(
     }
     n_batches = 0

-    with torch.no_grad():
+    with torch.no_grad(), accelerator.autocast():
         fid_metric.reset()
         for x, y in tqdm(loader):
             # Forward: Generate fakes and reconstructions
             fake_x = F(y)
             fake_y = G(x)
             rec_x = F(fake_y)
             rec_y = G(fake_x)
+            # ------ Evaluate Generators ------
+            # Loss 1: adversarial terms
+            fake_test_logits = DX(fake_x)  # fake x logits
+            loss_f_adv = lambda_adv * mse(
+                fake_test_logits, torch.ones_like(fake_test_logits)
+            )
+
+            fake_test_logits = DY(fake_y)  # fake y logits
+            loss_g_adv = lambda_adv * mse(
+                fake_test_logits, torch.ones_like(fake_test_logits)
+            )
+            # Loss 2: cycle terms
+            loss_cyc = lambda_cyc * (l1(rec_x, x) + l1(rec_y, y))
+            # Loss 3: identity terms
+            loss_id = lambda_id * (l1(G(y), y) + l1(F(x), x))
+            # Total loss
+            loss_gen_total = loss_g_adv + loss_f_adv + loss_cyc + loss_id

             # ------ Evaluate Discriminators ------
             # DX: real young vs fake young
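
evaluate_epoch takes the accelerator purely for this context manager: Accelerator.autocast() applies whatever mixed_precision policy the Accelerator was constructed with, so evaluation now runs under the same precision as training. A minimal self-contained sketch (toy module; bf16 chosen because it also works on CPU, and the repo's actual precision setting is not shown in this diff):

    import torch
    from torch import nn
    from accelerate import Accelerator

    accelerator = Accelerator(mixed_precision="bf16")  # "fp16" needs a GPU
    model = accelerator.prepare(nn.Linear(8, 8))

    x = torch.randn(4, 8, device=accelerator.device)
    with torch.no_grad(), accelerator.autocast():
        out = model(x)  # matmul runs in bf16 under the policy
    print(out.dtype)  # torch.bfloat16
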
@@ -338,23 +356,6 @@ def evaluate_epoch(
                 real_loss + fake_loss
             )  # average loss to prevent the discriminator from learning "too quickly" compared to the generators.

-            # ------ Evaluate Generators ------
-            # Loss 1: adversarial terms
-            fake_test_logits = DX(fake_x)  # fake x logits
-            loss_f_adv = lambda_adv * mse(
-                fake_test_logits, torch.ones_like(fake_test_logits)
-            )
-
-            fake_test_logits = DY(fake_y)  # fake y logits
-            loss_g_adv = lambda_adv * mse(
-                fake_test_logits, torch.ones_like(fake_test_logits)
-            )
-            # Loss 2: cycle terms
-            loss_cyc = lambda_cyc * (l1(rec_x, x) + l1(rec_y, y))
-            # Loss 3: identity terms
-            loss_id = lambda_id * (l1(G(y), y) + l1(F(x), x))
-            # Total loss
-            loss_gen_total = loss_g_adv + loss_f_adv + loss_cyc + loss_id
             # FID metric (normalize to range of [0,1] from [-1,1])
             # FID expects float32 images, which can raise dtype warning for mixed precision batches unless converted.
             fid_metric.update((y * 0.5 + 0.5).float(), real=True)
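
On the two FID comment lines above: assuming fid_metric is torchmetrics' FrechetInceptionDistance constructed with normalize=True (its construction is outside this diff), it expects float images in [0, 1], hence the rescale from the generators' tanh range of [-1, 1] and the .float() cast that guards against half-precision batches under autocast. A minimal sketch:

    import torch
    from torchmetrics.image.fid import FrechetInceptionDistance

    # Assumption: normalize=True, i.e. float inputs in [0, 1].
    fid = FrechetInceptionDistance(feature=2048, normalize=True)

    y = torch.rand(4, 3, 64, 64) * 2 - 1       # stand-in real batch in [-1, 1]
    fake_y = torch.rand(4, 3, 64, 64) * 2 - 1  # stand-in fake batch in [-1, 1]

    fid.update((y * 0.5 + 0.5).float(), real=True)        # rescale to [0, 1]
    fid.update((fake_y * 0.5 + 0.5).float(), real=False)
    print(fid.compute())
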
@@ -471,6 +472,7 @@ def perform_epoch(
         lambda_cyc,
         lambda_id,  # loss functions and loss params
         fid_metric,  # evaluation metric
+        accelerator,
     )
     logger.info(
         f"val/loss_DX: {val_metrics['val/loss_DX']:.4f} | val/loss_DY: {val_metrics['val/loss_DY']:.4f} | val/fid_val: {val_metrics['val/fid_val']:.4f} | val/loss_gen_total: {val_metrics['val/loss_gen_total']:.4f} | val/loss_g_adv: {val_metrics['val/loss_g_adv']:.4f} | val/loss_f_adv: {val_metrics['val/loss_f_adv']:.4f} | val/loss_cyc: {val_metrics['val/loss_cyc']:.4f} | val/loss_id: {val_metrics['val/loss_id']:.4f}"
