From 678364c309bdb3efe7e1481139cd8a6df1b41595 Mon Sep 17 00:00:00 2001
From: lavinal712
Date: Wed, 19 Mar 2025 05:12:33 +0000
Subject: [PATCH 1/2] add disc_optimizer step (not fix)

---
 .../autoencoderkl/train_autoencoderkl.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/examples/research_projects/autoencoderkl/train_autoencoderkl.py b/examples/research_projects/autoencoderkl/train_autoencoderkl.py
index cf13ecdbf8ac..91ee58cfabc6 100644
--- a/examples/research_projects/autoencoderkl/train_autoencoderkl.py
+++ b/examples/research_projects/autoencoderkl/train_autoencoderkl.py
@@ -951,13 +951,20 @@ def load_model_hook(models, input_dir):
                     logits_fake = discriminator(reconstructions)
                     disc_loss = hinge_d_loss if args.disc_loss == "hinge" else vanilla_d_loss
                     disc_factor = args.disc_factor if global_step >= args.disc_start else 0.0
-                    disc_loss = disc_factor * disc_loss(logits_real, logits_fake)
+                    d_loss = disc_factor * disc_loss(logits_real, logits_fake)
                     logs = {
-                        "disc_loss": disc_loss.detach().mean().item(),
+                        "disc_loss": d_loss.detach().mean().item(),
                         "logits_real": logits_real.detach().mean().item(),
                         "logits_fake": logits_fake.detach().mean().item(),
                         "disc_lr": disc_lr_scheduler.get_last_lr()[0],
                     }
+                    accelerator.backward(d_loss)
+                    if accelerator.sync_gradients:
+                        params_to_clip = discriminator.parameters()
+                        accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+                    disc_optimizer.step()
+                    disc_lr_scheduler.step()
+                    disc_optimizer.zero_grad(set_to_none=args.set_grads_to_none)
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 progress_bar.update(1)

From 2e87bfc5c1e5233e1e91d772c2600453c0ead082 Mon Sep 17 00:00:00 2001
From: lavinal712
Date: Wed, 19 Mar 2025 06:58:31 +0000
Subject: [PATCH 2/2] support syncbatchnorm in discriminator

---
 examples/research_projects/autoencoderkl/train_autoencoderkl.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/examples/research_projects/autoencoderkl/train_autoencoderkl.py b/examples/research_projects/autoencoderkl/train_autoencoderkl.py
index 91ee58cfabc6..31cf8414ac10 100644
--- a/examples/research_projects/autoencoderkl/train_autoencoderkl.py
+++ b/examples/research_projects/autoencoderkl/train_autoencoderkl.py
@@ -627,6 +627,7 @@ def main(args):
         ema_vae = EMAModel(vae.parameters(), model_cls=AutoencoderKL, model_config=vae.config)
     perceptual_loss = lpips.LPIPS(net="vgg").eval()
     discriminator = NLayerDiscriminator(input_nc=3, n_layers=3, use_actnorm=False).apply(weights_init)
+    discriminator = torch.nn.SyncBatchNorm.convert_sync_batchnorm(discriminator)
 
     # Taken from [Sayak Paul's Diffusers PR #6511](https://github.com/huggingface/diffusers/pull/6511/files)
     def unwrap_model(model):
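
The sketch below is not part of the patches; it is a minimal, self-contained illustration of the pattern they introduce: convert the discriminator's BatchNorm layers to SyncBatchNorm before accelerator.prepare() (patch 2), then backpropagate the discriminator loss, clip gradients, and step disc_optimizer / disc_lr_scheduler inside the accumulation context (patch 1). The tiny nn.Sequential discriminator, dummy tensors, learning rate, and max_norm value are illustrative placeholders rather than the script's NLayerDiscriminator / LPIPS setup, and it is intended to be run with accelerate launch.

# Minimal sketch of the pattern the two patches introduce; placeholder model and data,
# not the actual train_autoencoderkl.py setup.
import torch
import torch.nn as nn
import torch.nn.functional as F
from accelerate import Accelerator

accelerator = Accelerator()

# Stand-in PatchGAN-style discriminator containing BatchNorm layers.
discriminator = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=4, stride=2, padding=1),
    nn.BatchNorm2d(16),
    nn.LeakyReLU(0.2),
    nn.Conv2d(16, 1, kernel_size=4),
)
# Patch 2: synchronize BatchNorm statistics across processes in multi-GPU runs
# (falls back to regular BatchNorm behavior in a single-process run).
discriminator = torch.nn.SyncBatchNorm.convert_sync_batchnorm(discriminator)

disc_optimizer = torch.optim.Adam(discriminator.parameters(), lr=1e-4)
disc_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(disc_optimizer, lambda _: 1.0)
discriminator, disc_optimizer, disc_lr_scheduler = accelerator.prepare(
    discriminator, disc_optimizer, disc_lr_scheduler
)


def hinge_d_loss(logits_real, logits_fake):
    # Standard hinge GAN discriminator loss (the script selects hinge or vanilla via --disc_loss).
    return 0.5 * (torch.mean(F.relu(1.0 - logits_real)) + torch.mean(F.relu(1.0 + logits_fake)))


# Dummy batch standing in for real images and VAE reconstructions.
targets = torch.randn(2, 3, 64, 64, device=accelerator.device)
reconstructions = torch.randn(2, 3, 64, 64, device=accelerator.device)

# Patch 1: the discriminator branch now backpropagates its loss and steps its own optimizer.
discriminator.train()
with accelerator.accumulate(discriminator):
    logits_real = discriminator(targets)
    logits_fake = discriminator(reconstructions)
    d_loss = hinge_d_loss(logits_real, logits_fake)
    accelerator.backward(d_loss)
    if accelerator.sync_gradients:
        accelerator.clip_grad_norm_(discriminator.parameters(), max_norm=1.0)
    disc_optimizer.step()
    disc_lr_scheduler.step()
    disc_optimizer.zero_grad(set_to_none=True)

accelerator.print("disc_loss:", d_loss.detach().item())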