@@ -140,7 +140,7 @@ def save_model_card(
140140 model_card = load_or_create_model_card (
141141 repo_id_or_path = repo_id ,
142142 from_training = True ,
143- license = "openrail++ " ,
143+ license = "other " ,
144144 base_model = base_model ,
145145 prompt = instance_prompt ,
146146 model_description = model_description ,
@@ -186,7 +186,7 @@ def log_validation(
186186 f"Running validation... \n Generating { args .num_validation_images } images with prompt:"
187187 f" { args .validation_prompt } ."
188188 )
189- pipeline = pipeline .to (accelerator .device , dtype = torch_dtype )
189+ pipeline = pipeline .to (accelerator .device )
190190 pipeline .set_progress_bar_config (disable = True )
191191
192192 # run inference
@@ -608,6 +608,12 @@ def parse_args(input_args=None):
608608 " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
609609 ),
610610 )
611+ parser .add_argument (
612+ "--cache_latents" ,
613+ action = "store_true" ,
614+ default = False ,
615+ help = "Cache the VAE latents" ,
616+ )
611617 parser .add_argument (
612618 "--report_to" ,
613619 type = str ,
@@ -628,6 +634,15 @@ def parse_args(input_args=None):
628634 " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
629635 ),
630636 )
637+ parser .add_argument (
638+ "--upcast_before_saving" ,
639+ action = "store_true" ,
640+ default = False ,
641+ help = (
642+ "Whether to upcast the trained transformer layers to float32 before saving (at the end of training). "
643+ "Defaults to precision dtype used for training to save memory"
644+ ),
645+ )
631646 parser .add_argument (
632647 "--prior_generation_precision" ,
633648 type = str ,
@@ -1394,6 +1409,16 @@ def load_model_hook(models, input_dir):
13941409 logger .warning (
13951410 "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
13961411 )
1412+ if args .train_text_encoder and args .text_encoder_lr :
1413+ logger .warning (
1414+ f"Learning rates were provided both for the transformer and the text encoder- e.g. text_encoder_lr:"
1415+ f" { args .text_encoder_lr } and learning_rate: { args .learning_rate } . "
1416+ f"When using prodigy only learning_rate is used as the initial learning rate."
1417+ )
1418+ # Prodigy uses only --learning_rate as its initial learning rate, so override the lr of the
1419+ # text_encoder_parameters_one and text_encoder_parameters_two param groups with --learning_rate
1420+ params_to_optimize [1 ]["lr" ] = args .learning_rate
1421+ params_to_optimize [2 ]["lr" ] = args .learning_rate
13971422
13981423 optimizer = optimizer_class (
13991424 params_to_optimize ,
@@ -1440,6 +1465,9 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
14401465 pooled_prompt_embeds = pooled_prompt_embeds .to (accelerator .device )
14411466 return prompt_embeds , pooled_prompt_embeds
14421467
1468+ # If the text encoders are not being trained and custom instance prompts are NOT
1469+ # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once
1470+ # to avoid redundantly re-encoding it.
14431471 if not args .train_text_encoder and not train_dataset .custom_instance_prompts :
14441472 instance_prompt_hidden_states , instance_pooled_prompt_embeds = compute_text_embeddings (
14451473 args .instance_prompt , text_encoders , tokenizers
@@ -1484,6 +1512,21 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
14841512 tokens_two = torch .cat ([tokens_two , class_tokens_two ], dim = 0 )
14851513 tokens_three = torch .cat ([tokens_three , class_tokens_three ], dim = 0 )
14861514
1515+ vae_config_shift_factor = vae .config .shift_factor
1516+ vae_config_scaling_factor = vae .config .scaling_factor
1517+ if args .cache_latents :
1518+ latents_cache = []
1519+ for batch in tqdm (train_dataloader , desc = "Caching latents" ):
1520+ with torch .no_grad ():
1521+ batch ["pixel_values" ] = batch ["pixel_values" ].to (
1522+ accelerator .device , non_blocking = True , dtype = weight_dtype
1523+ )
1524+ latents_cache .append (vae .encode (batch ["pixel_values" ]).latent_dist )
1525+
1526+ if args .validation_prompt is None :
1527+ del vae
1528+ free_memory ()
1529+
14871530 # Scheduler and math around the number of training steps.
14881531 overrode_max_train_steps = False
14891532 num_update_steps_per_epoch = math .ceil (len (train_dataloader ) / args .gradient_accumulation_steps )
@@ -1500,7 +1543,6 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
15001543 power = args .lr_power ,
15011544 )
15021545
1503- # Prepare everything with our `accelerator`.
15041546 # Prepare everything with our `accelerator`.
15051547 if args .train_text_encoder :
15061548 (
@@ -1607,8 +1649,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
16071649
16081650 for step , batch in enumerate (train_dataloader ):
16091651 models_to_accumulate = [transformer ]
1652+ if args .train_text_encoder :
1653+ models_to_accumulate .extend ([text_encoder_one , text_encoder_two ])
16101654 with accelerator .accumulate (models_to_accumulate ):
1611- pixel_values = batch ["pixel_values" ].to (dtype = vae .dtype )
16121655 prompts = batch ["prompts" ]
16131656
16141657 # encode batch prompts when custom prompts are provided for each image -
@@ -1639,8 +1682,13 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
16391682 )
16401683
16411684 # Convert images to latent space
1642- model_input = vae .encode (pixel_values ).latent_dist .sample ()
1643- model_input = (model_input - vae .config .shift_factor ) * vae .config .scaling_factor
1685+ if args .cache_latents :
1686+ model_input = latents_cache [step ].sample ()
1687+ else :
1688+ pixel_values = batch ["pixel_values" ].to (dtype = vae .dtype )
1689+ model_input = vae .encode (pixel_values ).latent_dist .sample ()
1690+
1691+ model_input = (model_input - vae_config_shift_factor ) * vae_config_scaling_factor
16441692 model_input = model_input .to (dtype = weight_dtype )
16451693
16461694 # Sample noise that we'll add to the latents
@@ -1773,6 +1821,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
17731821 text_encoder_one , text_encoder_two , text_encoder_three = load_text_encoders (
17741822 text_encoder_cls_one , text_encoder_cls_two , text_encoder_cls_three
17751823 )
1824+ text_encoder_one .to (weight_dtype )
1825+ text_encoder_two .to (weight_dtype )
17761826 pipeline = StableDiffusion3Pipeline .from_pretrained (
17771827 args .pretrained_model_name_or_path ,
17781828 vae = vae ,
@@ -1793,15 +1843,18 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
17931843 epoch = epoch ,
17941844 torch_dtype = weight_dtype ,
17951845 )
1796-
1797- del text_encoder_one , text_encoder_two , text_encoder_three
1798- free_memory ()
1846+ if not args . train_text_encoder :
1847+ del text_encoder_one , text_encoder_two , text_encoder_three
1848+ free_memory ()
17991849
18001850 # Save the lora layers
18011851 accelerator .wait_for_everyone ()
18021852 if accelerator .is_main_process :
18031853 transformer = unwrap_model (transformer )
1804- transformer = transformer .to (torch .float32 )
1854+ if args .upcast_before_saving :
1855+ transformer .to (torch .float32 )
1856+ else :
1857+ transformer = transformer .to (weight_dtype )
18051858 transformer_lora_layers = get_peft_model_state_dict (transformer )
18061859
18071860 if args .train_text_encoder :
0 commit comments