 
 
 @torch.no_grad()
-def log_validation(
-    vae, args, accelerator, weight_dtype, step, is_final_validation=False
-):
+def log_validation(vae, args, accelerator, weight_dtype, step, is_final_validation=False):
     logger.info("Running validation... ")
 
     if not is_final_validation:
@@ -91,23 +89,18 @@ def log_validation(
         with inference_ctx:
             reconstructions = vae(targets).sample
 
-        images.append(
-            torch.cat([targets.cpu(), reconstructions.cpu()], axis=0)
-        )
+        images.append(torch.cat([targets.cpu(), reconstructions.cpu()], axis=0))
 
     tracker_key = "test" if is_final_validation else "validation"
     for tracker in accelerator.trackers:
         if tracker.name == "tensorboard":
             np_images = np.stack([np.asarray(img) for img in images])
-            tracker.writer.add_images(
-                f"{tracker_key}", np_images, step
-            )
+            tracker.writer.add_images(f"{tracker_key}", np_images, step)
         elif tracker.name == "wandb":
             tracker.log(
                 {
                     f"{tracker_key}": [
-                        wandb.Image(torchvision.utils.make_grid(image))
-                        for _, image in enumerate(images)
+                        wandb.Image(torchvision.utils.make_grid(image)) for _, image in enumerate(images)
                     ]
                 }
             )
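
For reference, the grid construction these trackers log can be exercised standalone. A minimal sketch follows; the batch shapes and the `nrow` value are illustrative assumptions, not values taken from this script:

```python
import torch
import torchvision

# Stand-ins for the validation batch and vae(targets).sample (shapes assumed).
targets = torch.rand(4, 3, 64, 64)
reconstructions = torch.rand(4, 3, 64, 64)

# Concatenating along the batch axis puts all originals first, then all
# reconstructions, exactly as log_validation appends them above.
pairs = torch.cat([targets.cpu(), reconstructions.cpu()], axis=0)

# make_grid flattens the batch into a single image; with nrow equal to the
# batch size, originals land on the top row and reconstructions directly below.
grid = torchvision.utils.make_grid(pairs, nrow=4)
print(grid.shape)  # torch.Size([3, 134, 266]) with the default padding=2
```
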
@@ -677,7 +670,9 @@ def load_model_hook(models, input_dir):
 
                 # pop models so that they are not loaded again
                 model = models.pop()
-                load_model = NLayerDiscriminator(input_nc=3, n_layers=3, use_actnorm=False).load_state_dict(os.path.join(input_dir, "discriminator", "pytorch_model.bin"))
+                load_model = NLayerDiscriminator(input_nc=3, n_layers=3, use_actnorm=False)
+                state_dict = torch.load(os.path.join(input_dir, "discriminator", "pytorch_model.bin"))
+                load_model.load_state_dict(state_dict)
                 model.load_state_dict(load_model.state_dict())
                 del load_model
 
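
The matching `save_model_hook` body falls outside this hunk. Under the assumption that it simply mirrors the load path above (the directory and file names are taken from `load_model_hook`; the rest is a sketch, not the script's actual code):

```python
import os
import torch

def save_model_hook(models, weights, output_dir):
    # Sketch of the counterpart hook: persist the discriminator so that
    # load_model_hook can find <output_dir>/discriminator/pytorch_model.bin.
    for model in models:
        ckpt_dir = os.path.join(output_dir, "discriminator")
        os.makedirs(ckpt_dir, exist_ok=True)
        torch.save(model.state_dict(), os.path.join(ckpt_dir, "pytorch_model.bin"))
        # Drop the cached weights so accelerate does not serialize them a second time.
        weights.pop()
```
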
@@ -690,7 +685,6 @@ def load_model_hook(models, input_dir):
         accelerator.register_save_state_pre_hook(save_model_hook)
         accelerator.register_load_state_pre_hook(load_model_hook)
 
-
     vae.requires_grad_(True)
     if args.decoder_only:
         vae.encoder.requires_grad_(False)
@@ -723,9 +717,7 @@ def load_model_hook(models, input_dir):
     )
 
     if unwrap_model(vae).dtype != torch.float32:
-        raise ValueError(
-            f"VAE loaded as datatype {unwrap_model(vae).dtype}. {low_precision_error_string}"
-        )
+        raise ValueError(f"VAE loaded as datatype {unwrap_model(vae).dtype}. {low_precision_error_string}")
 
     # Enable TF32 for faster training on Ampere GPUs,
     # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
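
The assignment itself sits just past this hunk boundary, so it is not shown here. For reference, the standard PyTorch switches the linked note describes are:

```python
import torch

# TF32 on Ampere+ GPUs (see the linked note): faster matmuls and convolutions
# at slightly reduced precision. Both flags are documented torch.backends options.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
```
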
@@ -802,7 +794,15 @@ def load_model_hook(models, input_dir):
     )
 
     # Prepare everything with our `accelerator`.
-    vae, discriminator, optimizer, disc_optimizer, train_dataloader, lr_scheduler, disc_lr_scheduler = accelerator.prepare(
+    (
+        vae,
+        discriminator,
+        optimizer,
+        disc_optimizer,
+        train_dataloader,
+        lr_scheduler,
+        disc_lr_scheduler,
+    ) = accelerator.prepare(
         vae, discriminator, optimizer, disc_optimizer, train_dataloader, lr_scheduler, disc_lr_scheduler
     )
 
@@ -935,7 +935,7 @@ def load_model_hook(models, input_dir):
                         "disc_weight": disc_weight.detach().mean().item(),
                         "disc_factor": disc_factor,
                         "g_loss": g_loss.detach().mean().item(),
-                        "lr": lr_scheduler.get_last_lr()[0]
+                        "lr": lr_scheduler.get_last_lr()[0],
                     }
 
                     accelerator.backward(loss)
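
The `disc_weight` in these logs is consistent with the VQGAN-style adaptive generator weight. A sketch under that assumption (this helper is not shown in the diff):

```python
import torch

def calculate_adaptive_weight(nll_loss, g_loss, last_layer, max_weight=1e4):
    # taming-transformers heuristic: rescale the adversarial loss so that its
    # gradient norm at the decoder's last layer matches the reconstruction
    # loss gradient norm, keeping the two objectives balanced.
    nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
    g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
    d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
    return torch.clamp(d_weight, 0.0, max_weight).detach()
```
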
@@ -956,7 +956,7 @@ def load_model_hook(models, input_dir):
                         "disc_loss": disc_loss.detach().mean().item(),
                         "logits_real": logits_real.detach().mean().item(),
                         "logits_fake": logits_fake.detach().mean().item(),
-                        "disc_lr": disc_lr_scheduler.get_last_lr()[0]
+                        "disc_lr": disc_lr_scheduler.get_last_lr()[0],
                     }
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
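
The `logits_real`, `logits_fake`, and `disc_loss` entries above match the usual hinge formulation for a PatchGAN discriminator. A sketch, assuming that is the loss used here:

```python
import torch
import torch.nn.functional as F

def hinge_d_loss(logits_real, logits_fake):
    # Push real-patch logits above +1 and fake-patch logits below -1,
    # averaging over all patches in the batch.
    loss_real = torch.mean(F.relu(1.0 - logits_real))
    loss_fake = torch.mean(F.relu(1.0 + logits_fake))
    return 0.5 * (loss_real + loss_fake)
```
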