@@ -627,6 +627,7 @@ def main(args):
     ema_vae = EMAModel(vae.parameters(), model_cls=AutoencoderKL, model_config=vae.config)
     perceptual_loss = lpips.LPIPS(net="vgg").eval()
     discriminator = NLayerDiscriminator(input_nc=3, n_layers=3, use_actnorm=False).apply(weights_init)
+    discriminator = torch.nn.SyncBatchNorm.convert_sync_batchnorm(discriminator)
 
     # Taken from [Sayak Paul's Diffusers PR #6511](https://github.com/huggingface/diffusers/pull/6511/files)
     def unwrap_model(model):
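The added line converts the discriminator's BatchNorm layers to SyncBatchNorm, so that when the script is launched with more than one process the batch statistics are all-reduced across ranks instead of being computed from each GPU's small local batch. A minimal sketch of what the conversion does, using a hypothetical Conv/BatchNorm stack as a stand-in for the PatchGAN discriminator:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the PatchGAN discriminator: a small Conv/BatchNorm stack.
# (The real script builds NLayerDiscriminator(input_nc=3, n_layers=3, use_actnorm=False).)
disc = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
    nn.BatchNorm2d(64),
    nn.LeakyReLU(0.2, inplace=True),
    nn.Conv2d(64, 1, kernel_size=4, stride=1, padding=1),
)

# convert_sync_batchnorm recursively replaces every nn.BatchNorm*d module with
# nn.SyncBatchNorm, copying weights, biases and running statistics. Synchronization
# of statistics across ranks only happens once the module runs under
# torch.distributed / DistributedDataParallel; the conversion itself needs no setup.
disc = torch.nn.SyncBatchNorm.convert_sync_batchnorm(disc)
print(type(disc[1]))  # torch.nn.modules.batchnorm.SyncBatchNorm
```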
@@ -951,13 +952,20 @@ def load_model_hook(models, input_dir):
                     logits_fake = discriminator(reconstructions)
                     disc_loss = hinge_d_loss if args.disc_loss == "hinge" else vanilla_d_loss
                     disc_factor = args.disc_factor if global_step >= args.disc_start else 0.0
-                    disc_loss = disc_factor * disc_loss(logits_real, logits_fake)
+                    d_loss = disc_factor * disc_loss(logits_real, logits_fake)
                     logs = {
-                        "disc_loss": disc_loss.detach().mean().item(),
+                        "disc_loss": d_loss.detach().mean().item(),
                         "logits_real": logits_real.detach().mean().item(),
                         "logits_fake": logits_fake.detach().mean().item(),
                         "disc_lr": disc_lr_scheduler.get_last_lr()[0],
                     }
+                    accelerator.backward(d_loss)
+                    if accelerator.sync_gradients:
+                        params_to_clip = discriminator.parameters()
+                        accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+                    disc_optimizer.step()
+                    disc_lr_scheduler.step()
+                    disc_optimizer.zero_grad(set_to_none=args.set_grads_to_none)
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 progress_bar.update(1)
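The second hunk fixes the discriminator branch: the computed loss is renamed to d_loss so it no longer shadows the disc_loss function selected just above, and the branch now actually backpropagates, clips the discriminator's gradients, and steps disc_optimizer / disc_lr_scheduler (previously the loss was only logged). For reference, hinge_d_loss and vanilla_d_loss are the two standard discriminator objectives popularized by the taming-transformers codebase; their definitions are not part of this hunk, so the sketch below is only an assumption about what they compute:

```python
import torch
import torch.nn.functional as F

def hinge_d_loss(logits_real: torch.Tensor, logits_fake: torch.Tensor) -> torch.Tensor:
    # Hinge objective: push real logits above +1 and fake logits below -1.
    loss_real = torch.mean(F.relu(1.0 - logits_real))
    loss_fake = torch.mean(F.relu(1.0 + logits_fake))
    return 0.5 * (loss_real + loss_fake)

def vanilla_d_loss(logits_real: torch.Tensor, logits_fake: torch.Tensor) -> torch.Tensor:
    # BCE-style objective written with softplus.
    return 0.5 * (
        torch.mean(F.softplus(-logits_real)) + torch.mean(F.softplus(logits_fake))
    )
```

Because disc_factor stays at 0.0 until global_step reaches args.disc_start, the discriminator only begins receiving gradient after the autoencoder has had some steps to warm up.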