@@ -140,7 +140,7 @@ def save_model_card(
140140    model_card  =  load_or_create_model_card (
141141        repo_id_or_path = repo_id ,
142142        from_training = True ,
143-         license = "openrail++ " ,
143+         license = "other " ,
144144        base_model = base_model ,
145145        prompt = instance_prompt ,
146146        model_description = model_description ,
@@ -186,7 +186,7 @@ def log_validation(
186186        f"Running validation... \n  Generating { args .num_validation_images }  
187187        f" { args .validation_prompt }  
188188    )
189-     pipeline  =  pipeline .to (accelerator .device ,  dtype = torch_dtype )
189+     pipeline  =  pipeline .to (accelerator .device )
190190    pipeline .set_progress_bar_config (disable = True )
191191
192192    # run inference 
@@ -608,6 +608,12 @@ def parse_args(input_args=None):
608608            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" 
609609        ),
610610    )
611+     parser .add_argument (
612+         "--cache_latents" ,
613+         action = "store_true" ,
614+         default = False ,
615+         help = "Cache the VAE latents" ,
616+     )
611617    parser .add_argument (
612618        "--report_to" ,
613619        type = str ,
@@ -628,6 +634,15 @@ def parse_args(input_args=None):
628634            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." 
629635        ),
630636    )
637+     parser .add_argument (
638+         "--upcast_before_saving" ,
639+         action = "store_true" ,
640+         default = False ,
641+         help = (
642+             "Whether to upcast the trained transformer layers to float32 before saving (at the end of training). " 
643+             "Defaults to precision dtype used for training to save memory" 
644+         ),
645+     )
631646    parser .add_argument (
632647        "--prior_generation_precision" ,
633648        type = str ,
@@ -1394,6 +1409,16 @@ def load_model_hook(models, input_dir):
13941409            logger .warning (
13951410                "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0" 
13961411            )
1412+         if  args .train_text_encoder  and  args .text_encoder_lr :
1413+             logger .warning (
1414+                 f"Learning rates were provided both for the transformer and the text encoder- e.g. text_encoder_lr:" 
1415+                 f" { args .text_encoder_lr } { args .learning_rate }  
1416+                 f"When using prodigy only learning_rate is used as the initial learning rate." 
1417+             )
1418+             # Prodigy only uses --learning_rate as its initial learning rate, so override the lr of the
1419+             # text_encoder_parameters_one / text_encoder_parameters_two groups (indices 1 and 2) with it.
1420+             params_to_optimize [1 ]["lr" ] =  args .learning_rate 
1421+             params_to_optimize [2 ]["lr" ] =  args .learning_rate 
13971422
13981423        optimizer  =  optimizer_class (
13991424            params_to_optimize ,
@@ -1440,6 +1465,9 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
14401465                pooled_prompt_embeds  =  pooled_prompt_embeds .to (accelerator .device )
14411466            return  prompt_embeds , pooled_prompt_embeds 
14421467
1468+     # If no type of tuning is done on the text_encoder and custom instance prompts are NOT 
1469+     # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid 
1470+     # the redundant encoding. 
14431471    if  not  args .train_text_encoder  and  not  train_dataset .custom_instance_prompts :
14441472        instance_prompt_hidden_states , instance_pooled_prompt_embeds  =  compute_text_embeddings (
14451473            args .instance_prompt , text_encoders , tokenizers 
@@ -1484,6 +1512,21 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
14841512                tokens_two  =  torch .cat ([tokens_two , class_tokens_two ], dim = 0 )
14851513                tokens_three  =  torch .cat ([tokens_three , class_tokens_three ], dim = 0 )
14861514
1515+     vae_config_shift_factor  =  vae .config .shift_factor 
1516+     vae_config_scaling_factor  =  vae .config .scaling_factor 
1517+     if  args .cache_latents :
1518+         latents_cache  =  []
1519+         for  batch  in  tqdm (train_dataloader , desc = "Caching latents" ):
1520+             with  torch .no_grad ():
1521+                 batch ["pixel_values" ] =  batch ["pixel_values" ].to (
1522+                     accelerator .device , non_blocking = True , dtype = weight_dtype 
1523+                 )
1524+                 latents_cache .append (vae .encode (batch ["pixel_values" ]).latent_dist )
1525+ 
1526+         if  args .validation_prompt  is  None :
1527+             del  vae 
1528+             free_memory ()
1529+ 
14871530    # Scheduler and math around the number of training steps. 
14881531    overrode_max_train_steps  =  False 
14891532    num_update_steps_per_epoch  =  math .ceil (len (train_dataloader ) /  args .gradient_accumulation_steps )
@@ -1500,7 +1543,6 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
15001543        power = args .lr_power ,
15011544    )
15021545
1503-     # Prepare everything with our `accelerator`. 
15041546    # Prepare everything with our `accelerator`. 
15051547    if  args .train_text_encoder :
15061548        (
@@ -1607,8 +1649,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
16071649
16081650        for  step , batch  in  enumerate (train_dataloader ):
16091651            models_to_accumulate  =  [transformer ]
1652+             if  args .train_text_encoder :
1653+                 models_to_accumulate .extend ([text_encoder_one , text_encoder_two ])
16101654            with  accelerator .accumulate (models_to_accumulate ):
1611-                 pixel_values  =  batch ["pixel_values" ].to (dtype = vae .dtype )
16121655                prompts  =  batch ["prompts" ]
16131656
16141657                # encode batch prompts when custom prompts are provided for each image - 
@@ -1639,8 +1682,13 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
16391682                        )
16401683
16411684                # Convert images to latent space 
1642-                 model_input  =  vae .encode (pixel_values ).latent_dist .sample ()
1643-                 model_input  =  (model_input  -  vae .config .shift_factor ) *  vae .config .scaling_factor 
1685+                 if  args .cache_latents :
1686+                     model_input  =  latents_cache [step ].sample ()
1687+                 else :
1688+                     pixel_values  =  batch ["pixel_values" ].to (dtype = vae .dtype )
1689+                     model_input  =  vae .encode (pixel_values ).latent_dist .sample ()
1690+ 
1691+                 model_input  =  (model_input  -  vae_config_shift_factor ) *  vae_config_scaling_factor 
16441692                model_input  =  model_input .to (dtype = weight_dtype )
16451693
16461694                # Sample noise that we'll add to the latents 
@@ -1773,6 +1821,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
17731821                    text_encoder_one , text_encoder_two , text_encoder_three  =  load_text_encoders (
17741822                        text_encoder_cls_one , text_encoder_cls_two , text_encoder_cls_three 
17751823                    )
1824+                     text_encoder_one .to (weight_dtype )
1825+                     text_encoder_two .to (weight_dtype )
17761826                pipeline  =  StableDiffusion3Pipeline .from_pretrained (
17771827                    args .pretrained_model_name_or_path ,
17781828                    vae = vae ,
@@ -1793,15 +1843,18 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
17931843                    epoch = epoch ,
17941844                    torch_dtype = weight_dtype ,
17951845                )
1796- 
1797-                 del  text_encoder_one , text_encoder_two , text_encoder_three 
1798-                 free_memory ()
1846+                  if   not   args . train_text_encoder : 
1847+                      del  text_encoder_one , text_encoder_two , text_encoder_three 
1848+                      free_memory ()
17991849
18001850    # Save the lora layers 
18011851    accelerator .wait_for_everyone ()
18021852    if  accelerator .is_main_process :
18031853        transformer  =  unwrap_model (transformer )
1804-         transformer  =  transformer .to (torch .float32 )
1854+         if  args .upcast_before_saving :
1855+             transformer .to (torch .float32 )
1856+         else :
1857+             transformer  =  transformer .to (weight_dtype )
18051858        transformer_lora_layers  =  get_peft_model_state_dict (transformer )
18061859
18071860        if  args .train_text_encoder :
0 commit comments