 )
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import (
+    _collate_lora_metadata,
     cast_training_params,
     compute_density_for_timestep_sampling,
     compute_loss_weighting_for_sd3,
@@ -365,7 +366,12 @@ def parse_args(input_args=None):
         default=4,
         help=("The dimension of the LoRA update matrices."),
     )
-
+    parser.add_argument(
+        "--lora_alpha",
+        type=int,
+        default=4,
+        help="LoRA alpha to be used for additional scaling.",
+    )
     parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
 
     parser.add_argument(
@@ -1078,7 +1084,7 @@ def main(args):
     # now we will add new LoRA weights to the transformer layers
     transformer_lora_config = LoraConfig(
         r=args.rank,
-        lora_alpha=args.rank,
+        lora_alpha=args.lora_alpha,
         lora_dropout=args.lora_dropout,
         init_lora_weights="gaussian",
         target_modules=target_modules,
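For context on the new flag: in PEFT, the LoRA delta is scaled by `lora_alpha / r`, so exposing `--lora_alpha` independently of `--rank` (e.g. `--rank 16 --lora_alpha 32`) changes the magnitude of the learned update without changing the number of trainable parameters. A minimal sketch of that relationship, with illustrative values and target modules that are not taken from this script:

```python
# Minimal sketch of how rank and lora_alpha interact in PEFT.
# Values and target_modules are illustrative, not from the training script.
from peft import LoraConfig

rank, lora_alpha = 4, 8
config = LoraConfig(
    r=rank,
    lora_alpha=lora_alpha,
    lora_dropout=0.0,
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)
# PEFT applies delta_W = (lora_alpha / r) * (B @ A), so this config doubles the
# update magnitude relative to the old lora_alpha == rank default.
print(config.lora_alpha / config.r)  # 2.0
```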
@@ -1094,11 +1100,13 @@ def unwrap_model(model):
     def save_model_hook(models, weights, output_dir):
         if accelerator.is_main_process:
             transformer_lora_layers_to_save = None
+            modules_to_save = {}
 
             for model in models:
                 if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
                     model = unwrap_model(model)
                     transformer_lora_layers_to_save = get_peft_model_state_dict(model)
+                    modules_to_save["transformer"] = model
                 else:
                     raise ValueError(f"unexpected save model: {model.__class__}")
 
@@ -1109,6 +1117,7 @@ def save_model_hook(models, weights, output_dir):
             QwenImagePipeline.save_lora_weights(
                 output_dir,
                 transformer_lora_layers=transformer_lora_layers_to_save,
+                **_collate_lora_metadata(modules_to_save),
             )
 
     def load_model_hook(models, input_dir):
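The `modules_to_save` dict exists so that `save_lora_weights` can store the adapter configuration (rank, alpha, target modules) next to the LoRA weights via `_collate_lora_metadata`, which spares consumers of the checkpoint from guessing those hyperparameters at load time. Roughly, and as an assumption about what the helper does rather than a copy of its implementation, the collation step looks like this:

```python
# Hypothetical sketch of the metadata collation step. The real helper in
# diffusers.training_utils may differ, e.g. in the exact keyword names it
# returns or in how it selects the adapter.
def collate_lora_metadata_sketch(modules_to_save):
    metadata_kwargs = {}
    for name, module in modules_to_save.items():
        # PEFT-enabled diffusers models keep their adapter configs in `peft_config`;
        # "default" is the adapter name used when add_adapter() is called without one.
        metadata_kwargs[f"{name}_lora_adapter_metadata"] = module.peft_config["default"].to_dict()
    return metadata_kwargs
```

The same pattern is applied once more in the final save at the end of training.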
@@ -1258,31 +1267,31 @@ def load_model_hook(models, input_dir):
 
     def compute_text_embeddings(prompt, text_encoding_pipeline):
         with torch.no_grad():
-            prompt_embeds, prompt_embeds_mask, text_ids = text_encoding_pipeline.encode_prompt(
+            prompt_embeds, prompt_embeds_mask = text_encoding_pipeline.encode_prompt(
                 prompt=prompt, max_sequence_length=args.max_sequence_length
             )
-        return prompt_embeds, prompt_embeds_mask, text_ids
+        return prompt_embeds, prompt_embeds_mask
 
     # If no type of tuning is done on the text_encoder and custom instance prompts are NOT
     # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
     # the redundant encoding.
     if not train_dataset.custom_instance_prompts:
         with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
-            instance_prompt_embeds, instance_prompt_embeds_mask, _ = compute_text_embeddings(
+            instance_prompt_embeds, instance_prompt_embeds_mask = compute_text_embeddings(
                 args.instance_prompt, text_encoding_pipeline
             )
 
     # Handle class prompt for prior-preservation.
     if args.with_prior_preservation:
         with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
-            class_prompt_embeds, class_prompt_embeds_mask, _ = compute_text_embeddings(
+            class_prompt_embeds, class_prompt_embeds_mask = compute_text_embeddings(
                 args.class_prompt, text_encoding_pipeline
             )
 
     validation_embeddings = {}
     if args.validation_prompt is not None:
         with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
-            (validation_embeddings["prompt_embeds"], validation_embeddings["prompt_embeds_mask"], _) = (
+            (validation_embeddings["prompt_embeds"], validation_embeddings["prompt_embeds_mask"]) = (
                 compute_text_embeddings(args.validation_prompt, text_encoding_pipeline)
             )
 
@@ -1314,7 +1323,7 @@ def compute_text_embeddings(prompt, text_encoding_pipeline):
                 latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
             if train_dataset.custom_instance_prompts:
                 with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
-                    prompt_embeds, prompt_embeds_mask, _ = compute_text_embeddings(
+                    prompt_embeds, prompt_embeds_mask = compute_text_embeddings(
                         batch["prompts"], text_encoding_pipeline
                     )
                 prompt_embeds_cache.append(prompt_embeds)
@@ -1438,8 +1447,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                     prompt_embeds = prompt_embeds_cache[step]
                     prompt_embeds_mask = prompt_embeds_mask_cache[step]
                 else:
-                    prompt_embeds = prompt_embeds.repeat(len(prompts), 1, 1)
-                    prompt_embeds_mask = prompt_embeds_mask.repeat(1, len(prompts), 1, 1)
+                    num_repeat_elements = len(prompts)
+                    prompt_embeds = prompt_embeds.repeat(num_repeat_elements, 1, 1)
+                    prompt_embeds_mask = prompt_embeds_mask.repeat(num_repeat_elements, 1)
                 # Convert images to latent space
                 if args.cache_latents:
                     model_input = latents_cache[step].sample()
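The repeat fix matches the shapes used in this script: the cached single-prompt embedding is a 3D tensor and its attention mask is 2D, so only the batch dimension needs to grow. A quick shape check under those assumed dimensions:

```python
# Shape sketch for the repeat above; the sizes are assumptions for illustration.
import torch

batch, seq_len, dim = 2, 8, 16
prompt_embeds = torch.randn(1, seq_len, dim)                   # single cached prompt
prompt_embeds_mask = torch.ones(1, seq_len, dtype=torch.long)  # 2D attention mask

num_repeat_elements = batch
prompt_embeds = prompt_embeds.repeat(num_repeat_elements, 1, 1)         # -> [batch, seq_len, dim]
prompt_embeds_mask = prompt_embeds_mask.repeat(num_repeat_elements, 1)  # -> [batch, seq_len]
assert prompt_embeds.shape == (batch, seq_len, dim)
assert prompt_embeds_mask.shape == (batch, seq_len)
```

The previous `.repeat(1, len(prompts), 1, 1)` applied to a 2D mask produces a 4D tensor (torch prepends singleton dimensions), which is what this change corrects.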
@@ -1485,6 +1495,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                     height=model_input.shape[3],
                     width=model_input.shape[4],
                 )
+                print(f"{prompt_embeds_mask.sum(dim=1).tolist()=}")
                 model_pred = transformer(
                     hidden_states=packed_noisy_model_input,
                     encoder_hidden_states=prompt_embeds,
@@ -1602,17 +1613,20 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
     # Save the lora layers
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
+        modules_to_save = {}
         transformer = unwrap_model(transformer)
         if args.bnb_quantization_config_path is None:
             if args.upcast_before_saving:
                 transformer.to(torch.float32)
             else:
                 transformer = transformer.to(weight_dtype)
         transformer_lora_layers = get_peft_model_state_dict(transformer)
+        modules_to_save["transformer"] = transformer
 
         QwenImagePipeline.save_lora_weights(
             save_directory=args.output_dir,
             transformer_lora_layers=transformer_lora_layers,
+            **_collate_lora_metadata(modules_to_save),
         )
 
         images = []
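After training, the directory written by `save_lora_weights` (now carrying its adapter metadata) can be loaded back for inference. A usage sketch, assuming `QwenImagePipeline` follows the standard diffusers LoRA-loading API and that `Qwen/Qwen-Image` is the intended base checkpoint:

```python
# Inference sketch; the model id, dtype, and prompt are assumptions.
import torch
from diffusers import QwenImagePipeline

pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)
pipe.load_lora_weights("path/to/output_dir")  # directory produced by save_lora_weights above
pipe.to("cuda")

image = pipe(prompt="a photo of sks dog in a bucket", num_inference_steps=28).images[0]
image.save("lora_sample.png")
```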