 from huggingface_hub import create_repo, upload_folder
 from packaging import version
 from PIL import Image
-from taming.modules.losses.vqperceptual import *
+from taming.modules.losses.vqperceptual import (
+    hinge_d_loss, vanilla_d_loss, weights_init, NLayerDiscriminator
+)
 from torchvision import transforms
 from tqdm.auto import tqdm

 import diffusers
 from diffusers import AutoencoderKL
 from diffusers.optimization import get_scheduler
+from diffusers.training_utils import EMAModel
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
 from diffusers.utils.import_utils import is_xformers_available
@@ -56,6 +59,7 @@ def image_grid(imgs, rows, cols):
     return grid


+@torch.no_grad()
 def log_validation(
     vae, args, accelerator, weight_dtype, step, is_final_validation=False
 ):
@@ -80,8 +84,8 @@ def log_validation(

     for i, validation_image in enumerate(args.validation_image):
         validation_image = Image.open(validation_image).convert("RGB")
-        targets = image_transforms(validation_image).to(weight_dtype)
-        targets = targets.unsqueeze(0).to(vae.device)
+        targets = image_transforms(validation_image).to(accelerator.device, weight_dtype)
+        targets = targets.unsqueeze(0)

         with inference_ctx:
             reconstructions = vae(targets).sample
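The rewrite above moves the tensor to the accelerator's device and casts it in a single `.to(device, dtype)` call, and takes the device from the accelerator rather than reaching into the (possibly wrapped) `vae`. A minimal illustration of the single-call form, on CPU for portability:

```python
import torch

x = torch.randn(3, 256, 256)
# Tensor.to(device, dtype) moves and casts in one copy, whereas chained
# .to(device).to(dtype) calls can materialize two intermediate copies.
y = x.to(torch.device("cpu"), torch.float16).unsqueeze(0)
assert y.shape == (1, 3, 256, 256) and y.dtype == torch.float16
```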
@@ -112,15 +116,15 @@ def log_validation(
         gc.collect()
         torch.cuda.empty_cache()

-        return images
+    return images


 def save_model_card(repo_id: str, images=None, base_model=str, repo_folder=None):
     img_str = ""
     if images is not None:
         img_str = "You can find some example images below.\n\n"
-        image_grid(images, 1, "example").save(os.path.join(repo_folder, f"images_{i}.png"))
-        img_str += f"![images_{i}](./images_{i}.png)\n"
+        image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images.png"))
+        img_str += f"![images](./images.png)\n"

     model_description = f"""
# autoencoderkl-{repo_id}
@@ -156,9 +160,14 @@ def parse_args(input_args=None):
         "--pretrained_model_name_or_path",
         type=str,
         default=None,
-        required=True,
         help="Path to pretrained model or model identifier from huggingface.co/models.",
     )
+    parser.add_argument(
+        "--model_config_name_or_path",
+        type=str,
+        default=None,
+        help="The config of the VAE model to train, leave as None to use standard VAE model configuration.",
+    )
     parser.add_argument(
         "--revision",
         type=str,
@@ -242,6 +251,12 @@ def parse_args(input_args=None):
         default=4.5e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
+    parser.add_argument(
+        "--disc_learning_rate",
+        type=float,
+        default=4.5e-6,
+        help="Initial learning rate of the discriminator (after the potential warmup period) to use.",
+    )
     parser.add_argument(
         "--scale_lr",
         action="store_true",
@@ -257,6 +272,15 @@ def parse_args(input_args=None):
             ' "constant", "constant_with_warmup"]'
         ),
     )
+    parser.add_argument(
+        "--disc_lr_scheduler",
+        type=str,
+        default="constant",
+        help=(
+            'The scheduler type to use for the discriminator. Choose between ["linear", "cosine",'
+            ' "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"]'
+        ),
+    )
     parser.add_argument(
         "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
     )
@@ -270,6 +294,7 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
     )
+    parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
     parser.add_argument(
         "--dataloader_num_workers",
         type=int,
@@ -417,7 +442,7 @@ def parse_args(input_args=None):
         help="Scaling factor for the Kullback-Leibler divergence penalty term.",
     )
     parser.add_argument(
-        "--lpips_scale",
+        "--perceptual_scale",
         type=float,
         default=0.5,
         help="Scaling factor for the LPIPS metric",
     )
@@ -440,6 +465,12 @@ def parse_args(input_args=None):
         default=1.0,
         help="Scaling factor for the discriminator",
     )
+    parser.add_argument(
+        "--disc_loss",
+        type=str,
+        default="hinge",
+        help='Loss function for the discriminator: "hinge" or "vanilla".',
+    )
     parser.add_argument(
         "--decoder_only",
         action="store_true",
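The new `--disc_loss` flag is consumed later in the diff, where the script picks between `hinge_d_loss` and `vanilla_d_loss` imported from taming-transformers. For reference, those two objectives are essentially the following (a sketch of the taming implementations; verify against your installed version):

```python
import torch
import torch.nn.functional as F

def hinge_d_loss(logits_real, logits_fake):
    # Hinge loss: push real logits above +1 and fake logits below -1;
    # only margin violations contribute.
    loss_real = torch.mean(F.relu(1.0 - logits_real))
    loss_fake = torch.mean(F.relu(1.0 + logits_fake))
    return 0.5 * (loss_real + loss_fake)

def vanilla_d_loss(logits_real, logits_fake):
    # Standard non-saturating GAN loss, written with softplus for stability.
    return 0.5 * (
        torch.mean(F.softplus(-logits_real)) + torch.mean(F.softplus(logits_fake))
    )
```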
@@ -587,19 +618,28 @@ def main(args):
             ).repo_id

     # Load AutoencoderKL
-    vae = AutoencoderKL.from_pretrained(
-        args.pretrained_model_name_or_path, revision=args.revision
-    )
-    lpips_loss_fn = lpips.LPIPS(net="vgg")
-    discriminator = NLayerDiscriminator(
-        input_nc=3, n_layers=3, use_actnorm=False,
-    ).apply(weights_init)
+    if args.pretrained_model_name_or_path is None and args.model_config_name_or_path is None:
+        config = AutoencoderKL.load_config("stabilityai/sd-vae-ft-mse")
+        vae = AutoencoderKL.from_config(config)
+    elif args.pretrained_model_name_or_path is not None:
+        vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, revision=args.revision)
+    else:
+        config = AutoencoderKL.load_config(args.model_config_name_or_path)
+        vae = AutoencoderKL.from_config(config)
+    if args.use_ema:
+        ema_vae = EMAModel(vae.parameters(), model_cls=AutoencoderKL, model_config=vae.config)
+    perceptual_loss = lpips.LPIPS(net="vgg").eval()
+    discriminator = NLayerDiscriminator(input_nc=3, n_layers=3, use_actnorm=False).apply(weights_init)

     # `accelerate` 0.16.0 will have better support for customized saving
     if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
         def save_model_hook(models, weights, output_dir):
             if accelerator.is_main_process:
+                if args.use_ema:
+                    sub_dir = "autoencoderkl_ema"
+                    ema_vae.save_pretrained(os.path.join(output_dir, sub_dir))
+
                 i = len(weights) - 1

                 while len(weights) > 0:
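`EMAModel` maintains an exponential moving average of the VAE weights, and the rest of this diff wires up its full lifecycle: step after each optimizer update, swap in for validation, and save/load with checkpoints. Condensed into one place, assuming diffusers' `EMAModel` API:

```python
from diffusers import AutoencoderKL
from diffusers.training_utils import EMAModel

vae = AutoencoderKL()  # default config, for illustration only
ema_vae = EMAModel(vae.parameters(), model_cls=AutoencoderKL, model_config=vae.config)

# After every optimizer step: fold the current weights into the running average.
ema_vae.step(vae.parameters())

# For validation: stash the training weights and swap in the EMA weights ...
ema_vae.store(vae.parameters())
ema_vae.copy_to(vae.parameters())
# ... evaluate here, then restore the training weights and resume training.
ema_vae.restore(vae.parameters())
```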
@@ -618,13 +658,22 @@ def save_model_hook(models, weights, output_dir):

         def load_model_hook(models, input_dir):
             while len(models) > 0:
+                if args.use_ema:
+                    sub_dir = "autoencoderkl_ema"
+                    load_model = EMAModel.from_pretrained(os.path.join(input_dir, sub_dir), AutoencoderKL)
+                    ema_vae.load_state_dict(load_model.state_dict())
+                    ema_vae.to(accelerator.device)
+                    del load_model
+
                 # pop models so that they are not loaded again
                 model = models.pop()
-
-                # load diffusers style into model
+                load_model = NLayerDiscriminator(input_nc=3, n_layers=3, use_actnorm=False)
+                load_model.load_state_dict(torch.load(os.path.join(input_dir, "discriminator", "pytorch_model.bin")))
+                model.load_state_dict(load_model.state_dict())
+                del load_model
+
+                model = models.pop()
                 load_model = AutoencoderKL.from_pretrained(input_dir, subfolder="autoencoderkl")
                 model.register_to_config(**load_model.config)
-
                 model.load_state_dict(load_model.state_dict())
                 del load_model

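Two details worth noting. First, `models.pop()` returns models in reverse `accelerator.prepare` order, so the discriminator (prepared after the VAE) is popped first; that is why its branch precedes the `AutoencoderKL` one. Second, these hooks only take effect once registered with the accelerator, which the script does outside the shown hunks, following the usual diffusers pattern:

```python
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)
```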
@@ -638,7 +687,6 @@ def load_model_hook(models, input_dir):
         if getattr(vae, "quant_conv", None):
             vae.quant_conv.requires_grad_(False)
     vae.train()
-    lpips_loss_fn.requires_grad_(False)
     discriminator.requires_grad_(True)
     discriminator.train()

@@ -688,17 +736,17 @@ def load_model_hook(models, input_dir):
         optimizer_class = torch.optim.AdamW

     params_to_optimize = filter(lambda p: p.requires_grad, vae.parameters())
-    params_to_optimize_2 = filter(lambda p: p.requires_grad, discriminator.parameters())
+    disc_params_to_optimize = filter(lambda p: p.requires_grad, discriminator.parameters())
     optimizer = optimizer_class(
         params_to_optimize,
         lr=args.learning_rate,
         betas=(args.adam_beta1, args.adam_beta2),
         weight_decay=args.adam_weight_decay,
         eps=args.adam_epsilon,
     )
-    optimizer_2 = optimizer_class(
-        params_to_optimize_2,
-        lr=args.learning_rate,
+    disc_optimizer = optimizer_class(
+        disc_params_to_optimize,
+        lr=args.disc_learning_rate,
         betas=(args.adam_beta1, args.adam_beta2),
         weight_decay=args.adam_weight_decay,
         eps=args.adam_epsilon,
@@ -729,10 +777,18 @@ def load_model_hook(models, input_dir):
         num_cycles=args.lr_num_cycles,
         power=args.lr_power,
     )
+    disc_lr_scheduler = get_scheduler(
+        args.disc_lr_scheduler,
+        optimizer=disc_optimizer,
+        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+        num_training_steps=args.max_train_steps * accelerator.num_processes,
+        num_cycles=args.lr_num_cycles,
+        power=args.lr_power,
+    )

     # Prepare everything with our `accelerator`.
-    vae, discriminator, optimizer, optimizer_2, train_dataloader, lr_scheduler = accelerator.prepare(
-        vae, discriminator, optimizer, optimizer_2, train_dataloader, lr_scheduler
+    vae, discriminator, optimizer, disc_optimizer, train_dataloader, lr_scheduler, disc_lr_scheduler = accelerator.prepare(
+        vae, discriminator, optimizer, disc_optimizer, train_dataloader, lr_scheduler, disc_lr_scheduler
     )

     # For mixed precision training we cast the text_encoder and vae weights to half-precision
@@ -743,10 +799,12 @@ def load_model_hook(models, input_dir):
     elif accelerator.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16

-    # Move vae to device and cast to weight_dtype
+    # Move VAE, perceptual loss and discriminator to device and cast to weight_dtype
     vae.to(accelerator.device, dtype=weight_dtype)
-    lpips_loss_fn.to(accelerator.device, dtype=weight_dtype)
+    perceptual_loss.to(accelerator.device, dtype=weight_dtype)
     discriminator.to(accelerator.device, dtype=weight_dtype)
+    if args.use_ema:
+        ema_vae.to(accelerator.device, dtype=weight_dtype)

     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -812,6 +870,8 @@ def load_model_hook(models, input_dir):

     image_logs = None
     for epoch in range(first_epoch, args.num_train_epochs):
+        vae.train()
+        discriminator.train()
         for step, batch in enumerate(train_dataloader):
             # Convert images to latent space and reconstruct from them
             targets = batch["pixel_values"].to(dtype=weight_dtype)
@@ -834,9 +894,9 @@ def load_model_hook(models, input_dir):
                         rec_loss = F.l1_loss(reconstructions.float(), targets.float(), reduction="none")
                     # perceptual loss. The high level feature mean squared error loss
                     with torch.no_grad():
-                        lpips_loss = lpips_loss_fn(reconstructions, targets)
+                        p_loss = perceptual_loss(reconstructions, targets)

-                    rec_loss = rec_loss + args.lpips_scale * lpips_loss
+                    rec_loss = rec_loss + args.perceptual_scale * p_loss
                     nll_loss = rec_loss
                     nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]

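Note that `p_loss` is computed under `torch.no_grad()`, so as written the LPIPS term shifts the logged value of `rec_loss` but contributes no gradient to the VAE. Judging from the keys logged in the next hunk, the full generator-side objective is assembled roughly as follows (a schematic with dummy tensors; the scale variables stand in for the script's `args` and the adaptive weight computed later):

```python
import torch
import torch.nn.functional as F

x, x_hat = torch.rand(4, 3, 64, 64), torch.rand(4, 3, 64, 64)
perceptual_scale, kl_scale, disc_weight, disc_factor = 0.5, 1e-6, 1.0, 1.0
p_loss = torch.zeros(4, 1, 1, 1)  # LPIPS(x_hat, x); constant w.r.t. gradients,
                                  # since the script computes it under no_grad
kl_loss, g_loss = torch.tensor(0.0), torch.tensor(0.0)  # placeholders

rec_loss = F.l1_loss(x_hat, x, reduction="none") + perceptual_scale * p_loss
nll_loss = torch.sum(rec_loss) / rec_loss.shape[0]
loss = nll_loss + kl_scale * kl_loss + disc_weight * disc_factor * g_loss
```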
@@ -859,10 +919,10 @@ def load_model_hook(models, input_dir):
                         "loss": loss.detach().mean().item(),
                         "nll_loss": nll_loss.detach().mean().item(),
                         "rec_loss": rec_loss.detach().mean().item(),
-                        "lpips_loss": lpips_loss.detach().mean().item(),
+                        "p_loss": p_loss.detach().mean().item(),
                         "kl_loss": kl_loss.detach().mean().item(),
                         "disc_weight": disc_weight.detach().mean().item(),
-                        "disc_factor": torch.tensor(disc_factor),
+                        "disc_factor": disc_factor,
                         "g_loss": g_loss.detach().mean().item(),
                         "lr": lr_scheduler.get_last_lr()[0],
                     }
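The `disc_weight` logged here is taming-transformers' adaptive balancing weight: the ratio of the gradient norms of the reconstruction and adversarial losses at the decoder's last layer, so that neither term dominates. Roughly, per taming's `calculate_adaptive_weight`:

```python
import torch

def calculate_adaptive_weight(nll_loss, g_loss, last_layer, disc_weight_scale=1.0):
    # Gradient of each loss w.r.t. the decoder's final weight tensor.
    nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
    g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
    # Rescale the GAN term so its gradient magnitude matches the reconstruction term.
    d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
    return torch.clamp(d_weight, 0.0, 1e4).detach() * disc_weight_scale
```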
@@ -878,18 +938,21 @@ def load_model_hook(models, input_dir):
                 with accelerator.accumulate(discriminator):
                     logits_real = discriminator(targets)
                     logits_fake = discriminator(reconstructions)
-                    disc_loss = hinge_d_loss
+                    disc_loss = hinge_d_loss if args.disc_loss == "hinge" else vanilla_d_loss
                     disc_factor = args.disc_factor if global_step >= args.disc_start else 0.0
                     disc_loss = disc_factor * disc_loss(logits_real, logits_fake)
                     logs = {
                         "disc_loss": disc_loss.detach().mean().item(),
                         "logits_real": logits_real.detach().mean().item(),
                         "logits_fake": logits_fake.detach().mean().item(),
+                        "disc_lr": disc_lr_scheduler.get_last_lr()[0],
                     }
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 progress_bar.update(1)
                 global_step += 1
+                if args.use_ema:
+                    ema_vae.step(vae.parameters())

                 if accelerator.is_main_process:
                     if global_step % args.checkpointing_steps == 0:
@@ -918,13 +981,18 @@ def load_model_hook(models, input_dir):
                         logger.info(f"Saved state to {save_path}")

                     if global_step == 1 or global_step % args.validation_steps == 0:
+                        if args.use_ema:
+                            ema_vae.store(vae.parameters())
+                            ema_vae.copy_to(vae.parameters())
                         image_logs = log_validation(
                             vae,
                             args,
                             accelerator,
                             weight_dtype,
                             global_step,
                         )
+                        if args.use_ema:
+                            ema_vae.restore(vae.parameters())

             progress_bar.set_postfix(**logs)
             accelerator.log(logs, step=global_step)
@@ -936,8 +1004,11 @@ def load_model_hook(models, input_dir):
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
         vae = accelerator.unwrap_model(vae)
+        discriminator = accelerator.unwrap_model(discriminator)
+        if args.use_ema:
+            ema_vae.copy_to(vae.parameters())
         vae.save_pretrained(args.output_dir)
-
+        torch.save(discriminator.state_dict(), os.path.join(args.output_dir, "pytorch_model.bin"))
         # Run a final round of validation.
         image_logs = None
         image_logs = log_validation(
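Because the discriminator is saved as a raw state dict beside the VAE rather than in the diffusers format, reloading it later looks like this (a sketch; the directory is whatever `--output_dir` was at training time):

```python
import os
import torch
from taming.modules.losses.vqperceptual import NLayerDiscriminator

output_dir = "path/to/output_dir"  # hypothetical; use the training --output_dir
disc = NLayerDiscriminator(input_nc=3, n_layers=3, use_actnorm=False)
disc.load_state_dict(torch.load(os.path.join(output_dir, "pytorch_model.bin"), map_location="cpu"))
disc.eval()
```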