 import logging
 import math
 import os
-import random
 import shutil
 from pathlib import Path

 import accelerate
-import numpy as np
 import lpips
+import numpy as np
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
 from huggingface_hub import create_repo, upload_folder
 from packaging import version
 from PIL import Image
-from taming.modules.losses.vqperceptual import (
-    hinge_d_loss, vanilla_d_loss, weights_init, NLayerDiscriminator
-)
+from taming.modules.losses.vqperceptual import NLayerDiscriminator, hinge_d_loss, vanilla_d_loss, weights_init
 from torchvision import transforms
 from tqdm.auto import tqdm
@@ -93,22 +90,22 @@ def log_validation(

         with inference_ctx:
             reconstructions = vae(targets).sample
-        
+
         images.append(
             torch.cat([targets.cpu(), reconstructions.cpu()], axis=0)
         )
-    
+
     tracker_key = "test" if is_final_validation else "validation"
     for tracker in accelerator.trackers:
         if tracker.name == "tensorboard":
             np_images = np.stack([np.asarray(img) for img in images])
             tracker.writer.add_images(
-                "Original (left), Reconstruction (right)", np_images, step
+                f"{tracker_key}: Original (left), Reconstruction (right)", np_images, step
             )
         elif tracker.name == "wandb":
             tracker.log(
                 {
-                    "Original (left), Reconstruction (right)": [
+                    f"{tracker_key}: Original (left), Reconstruction (right)": [
                         wandb.Image(torchvision.utils.make_grid(image))
                         for _, image in enumerate(images)
                     ]
@@ -127,8 +124,8 @@ def save_model_card(repo_id: str, images=None, base_model=str, repo_folder=None)
     img_str = ""
     if images is not None:
         img_str = "You can find some example images below.\n\n"
-        make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images.png"))
-        img_str += f"![images](./images.png)\n"
+        make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, "images.png"))
+        img_str += "![images](./images.png)\n"

     model_description = f"""
 # autoencoderkl-{repo_id}
@@ -529,7 +526,7 @@ def make_train_dataset(args, accelerator):
     # Preprocessing the datasets.
     # We need to tokenize inputs and targets.
     column_names = dataset["train"].column_names
-    
+
     # 6. Get the column names for input/target.
     if args.image_column is None:
         image_column = column_names[0]
@@ -540,7 +537,7 @@ def make_train_dataset(args, accelerator):
             raise ValueError(
                 f"`--image_column` value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
             )
-    
+
     image_transforms = transforms.Compose(
         [
             transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
@@ -580,7 +577,7 @@ def main(args):
             "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
             " Please use `huggingface-cli login` to authenticate with the Hub."
         )
-    
+
     logging_dir = Path(args.output_dir, args.logging_dir)

     accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
@@ -591,7 +588,7 @@ def main(args):
         log_with=args.report_to,
         project_config=accelerator_project_config,
    )
-    
+
     # Disable AMP for MPS.
     if torch.backends.mps.is_available():
         accelerator.native_amp = False
@@ -623,7 +620,7 @@ def main(args):
             repo_id = create_repo(
                 repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
             ).repo_id
-    
+
     # Load AutoencoderKL
     if args.pretrained_model_name_or_path is None and args.model_config_name_or_path is None:
         config = AutoencoderKL.load_config("stabilityai/sd-vae-ft-mse")
@@ -637,7 +634,13 @@ def main(args):
         ema_vae = EMAModel(vae.parameters(), model_cls=AutoencoderKL, model_config=vae.config)
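     # LPIPS provides the perceptual loss (a VGG feature-space distance); NLayerDiscriminator is the
     # PatchGAN discriminator from taming-transformers, initialized with its standard weight init.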
     perceptual_loss = lpips.LPIPS(net="vgg").eval()
     discriminator = NLayerDiscriminator(input_nc=3, n_layers=3, use_actnorm=False).apply(weights_init)
-    
+
+    # Taken from [Sayak Paul's Diffusers PR #6511](https://github.com/huggingface/diffusers/pull/6511/files)
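+    # Unwraps `accelerate`'s distributed wrapper and, for models compiled with `torch.compile`,
+    # returns the original module so attributes like `.dtype` and `.config` resolve correctly.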
+    def unwrap_model(model):
+        model = accelerator.unwrap_model(model)
+        model = model._orig_mod if is_compiled_module(model) else model
+        return model
+
     # `accelerate` 0.16.0 will have better support for customized saving
     if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
@@ -677,7 +680,7 @@ def load_model_hook(models, input_dir):
                 load_model = NLayerDiscriminator(input_nc=3, n_layers=3, use_actnorm=False)
                 # `load_state_dict` expects a state dict rather than a path, so load the checkpoint first.
                 load_model.load_state_dict(torch.load(os.path.join(input_dir, "discriminator", "pytorch_model.bin")))
                 model.load_state_dict(load_model.state_dict())
                 del load_model
-                
+
                 model = models.pop()
                 load_model = AutoencoderKL.from_pretrained(input_dir, subfolder="autoencoderkl")
                 model.register_to_config(**load_model.config)
@@ -686,8 +689,8 @@ def load_model_hook(models, input_dir):

         accelerator.register_save_state_pre_hook(save_model_hook)
         accelerator.register_load_state_pre_hook(load_model_hook)
-    
-    
+
+
     vae.requires_grad_(True)
     if args.decoder_only:
         vae.encoder.requires_grad_(False)
@@ -696,7 +699,7 @@ def load_model_hook(models, input_dir):
     vae.train()
     discriminator.requires_grad_(True)
     discriminator.train()
-    
+
     if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():
             import xformers
@@ -709,16 +712,21 @@ def load_model_hook(models, input_dir):
             vae.enable_xformers_memory_efficient_attention()
         else:
             raise ValueError("xformers is not available. Make sure it is installed correctly")
-    
+
     if args.gradient_checkpointing:
         vae.enable_gradient_checkpointing()
-    
+
     # Check that all trainable models are in full precision
     low_precision_error_string = (
         " Please make sure to always have all model weights in full float32 precision when starting training - even if"
         " doing mixed precision training, copy of the weights should still be float32."
     )
-    
+
+    if unwrap_model(vae).dtype != torch.float32:
+        raise ValueError(
+            f"VAE loaded as datatype {unwrap_model(vae).dtype}. {low_precision_error_string}"
+        )
+
     # Enable TF32 for faster training on Ampere GPUs,
     # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
     if args.allow_tf32:
@@ -728,7 +736,7 @@ def load_model_hook(models, input_dir):
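         # Linear scaling rule: the effective batch size is train_batch_size * gradient_accumulation_steps
         # * num_processes, so the base learning rate is scaled up proportionally.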
         args.learning_rate = (
             args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
         )
-    
+
     # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
     if args.use_8bit_adam:
         try:
@@ -741,7 +749,7 @@ def load_model_hook(models, input_dir):
         optimizer_class = bnb.optim.AdamW8bit
     else:
         optimizer_class = torch.optim.AdamW
-    
+
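     # Only parameters left trainable are optimized (e.g. --decoder_only freezes the encoder above).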
     params_to_optimize = filter(lambda p: p.requires_grad, vae.parameters())
     disc_params_to_optimize = filter(lambda p: p.requires_grad, discriminator.parameters())
     optimizer = optimizer_class(
@@ -760,22 +768,22 @@ def load_model_hook(models, input_dir):
     )

     train_dataset = make_train_dataset(args, accelerator)
-    
+
     train_dataloader = torch.utils.data.DataLoader(
         train_dataset,
         shuffle=True,
         collate_fn=collate_fn,
         batch_size=args.train_batch_size,
         num_workers=args.dataloader_num_workers,
     )
-    
+
     # Scheduler and math around the number of training steps.
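     # One optimizer update consumes `gradient_accumulation_steps` dataloader batches; if --max_train_steps
     # is not set, it is derived from --num_train_epochs and recomputed after `accelerator.prepare` below.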
     overrode_max_train_steps = False
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True
-    
+
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
@@ -792,27 +800,27 @@ def load_model_hook(models, input_dir):
         num_cycles=args.lr_num_cycles,
         power=args.lr_power,
     )
-    
+
     # Prepare everything with our `accelerator`.
     vae, discriminator, optimizer, disc_optimizer, train_dataloader, lr_scheduler, disc_lr_scheduler = accelerator.prepare(
         vae, discriminator, optimizer, disc_optimizer, train_dataloader, lr_scheduler, disc_lr_scheduler
     )
-    
+
     # Resolve the weight dtype used for mixed precision; the VAE, perceptual loss and
     # discriminator are moved to the accelerator device and cast to this dtype below.
     weight_dtype = torch.float32
     if accelerator.mixed_precision == "fp16":
         weight_dtype = torch.float16
     elif accelerator.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16
-    
+
     # Move VAE, perceptual loss and discriminator to device and cast to weight_dtype
     vae.to(accelerator.device, dtype=weight_dtype)
     perceptual_loss.to(accelerator.device, dtype=weight_dtype)
     discriminator.to(accelerator.device, dtype=weight_dtype)
     if args.use_ema:
         ema_vae.to(accelerator.device, dtype=weight_dtype)
-    
+
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     if overrode_max_train_steps:
@@ -850,7 +858,7 @@ def load_model_hook(models, input_dir):
             dirs = [d for d in dirs if d.startswith("checkpoint")]
             dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
             path = dirs[-1] if len(dirs) > 0 else None
-        
+
         if path is None:
             accelerator.print(
                 f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
@@ -866,7 +874,7 @@ def load_model_hook(models, input_dir):
             first_epoch = global_step // num_update_steps_per_epoch
     else:
         initial_global_step = 0
-    
+
     progress_bar = tqdm(
         range(0, args.max_train_steps),
         initial=initial_global_step,
@@ -898,7 +906,7 @@ def load_model_hook(models, input_dir):
                     # perceptual loss. The high level feature mean squared error loss
                     with torch.no_grad():
                         p_loss = perceptual_loss(reconstructions, targets)
-                
+
                     rec_loss = rec_loss + args.perceptual_scale * p_loss
                     nll_loss = rec_loss
                     nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
@@ -915,9 +923,9 @@ def load_model_hook(models, input_dir):
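                     # Adaptive weight in the style of taming-transformers/VQGAN: keep the adversarial term
                     # comparable to the reconstruction loss, clamped to [0, 1e4] and detached from the graph.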
                     disc_weight = torch.clamp(disc_weight, 0.0, 1e4).detach()
                     disc_weight = disc_weight * args.disc_scale
                     disc_factor = args.disc_factor if global_step >= args.disc_start else 0.0
-                
+
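                     # Total generator objective: reconstruction/perceptual NLL + KL regularization
                     # + adversarial loss (zero until the discriminator warm-up step `disc_start`).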
                     loss = nll_loss + args.kl_scale * kl_loss + disc_weight * disc_factor * g_loss
-                    
+
                     logs = {
                         "loss": loss.detach().mean().item(),
                         "nll_loss": nll_loss.detach().mean().item(),
@@ -929,7 +937,7 @@ def load_model_hook(models, input_dir):
                         "g_loss": g_loss.detach().mean().item(),
                         "lr": lr_scheduler.get_last_lr()[0],
                     }
-                
+
                     accelerator.backward(loss)
                     if accelerator.sync_gradients:
                         params_to_clip = vae.parameters()
@@ -1002,7 +1010,7 @@ def load_model_hook(models, input_dir):

             if global_step >= args.max_train_steps:
                 break
-    
+
     # Create the pipeline using the trained modules and save it.
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
@@ -1036,7 +1044,7 @@ def load_model_hook(models, input_dir):
                 commit_message="End of training",
                 ignore_patterns=["step_*", "epoch_*"],
             )
-    
+
     accelerator.end_training()
10421050