5454)
5555from diffusers .optimization import get_scheduler
5656from diffusers .training_utils import (
57+ _collate_lora_metadata ,
5758 cast_training_params ,
5859 compute_density_for_timestep_sampling ,
5960 compute_loss_weighting_for_sd3 ,
@@ -420,6 +421,13 @@ def parse_args(input_args=None):
420421
421422 parser .add_argument ("--lora_dropout" , type = float , default = 0.0 , help = "Dropout probability for LoRA layers" )
422423
424+ parser .add_argument (
425+ "--lora_alpha" ,
426+ type = int ,
427+ default = 4 ,
428+ help = "LoRA alpha to be used for additional scaling." ,
429+ )
430+
423431 parser .add_argument (
424432 "--with_prior_preservation" ,
425433 default = False ,
@@ -1163,7 +1171,7 @@ def main(args):
11631171 # now we will add new LoRA weights the transformer layers
11641172 transformer_lora_config = LoraConfig (
11651173 r = args .rank ,
1166- lora_alpha = args .rank ,
1174+ lora_alpha = args .lora_alpha ,
11671175 lora_dropout = args .lora_dropout ,
11681176 init_lora_weights = "gaussian" ,
11691177 target_modules = target_modules ,
@@ -1180,10 +1188,12 @@ def save_model_hook(models, weights, output_dir):
11801188 if accelerator .is_main_process :
11811189 transformer_lora_layers_to_save = None
11821190
1191+ modules_to_save = {}
11831192 for model in models :
11841193 if isinstance (unwrap_model (model ), type (unwrap_model (transformer ))):
11851194 model = unwrap_model (model )
11861195 transformer_lora_layers_to_save = get_peft_model_state_dict (model )
1196+ modules_to_save ["transformer" ] = model
11871197 else :
11881198 raise ValueError (f"unexpected save model: { model .__class__ } " )
11891199
@@ -1194,6 +1204,7 @@ def save_model_hook(models, weights, output_dir):
11941204 HiDreamImagePipeline .save_lora_weights (
11951205 output_dir ,
11961206 transformer_lora_layers = transformer_lora_layers_to_save ,
1207+ ** _collate_lora_metadata (modules_to_save ),
11971208 )
11981209
11991210 def load_model_hook (models , input_dir ):
@@ -1496,6 +1507,7 @@ def compute_text_embeddings(prompt, text_encoding_pipeline):
14961507 # We need to initialize the trackers we use, and also store our configuration.
14971508 # The trackers initializes automatically on the main process.
14981509 if accelerator .is_main_process :
1510+ modules_to_save = {}
14991511 tracker_name = "dreambooth-hidream-lora"
15001512 accelerator .init_trackers (tracker_name , config = vars (args ))
15011513
@@ -1737,6 +1749,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
17371749 else :
17381750 transformer = transformer .to (weight_dtype )
17391751 transformer_lora_layers = get_peft_model_state_dict (transformer )
1752+ modules_to_save ["transformer" ] = transformer
17401753
17411754 HiDreamImagePipeline .save_lora_weights (
17421755 save_directory = args .output_dir ,
0 commit comments