 )
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import (
+    _collate_lora_metadata,
     cast_training_params,
     compute_density_for_timestep_sampling,
     compute_loss_weighting_for_sd3,
@@ -323,9 +324,13 @@ def parse_args(input_args=None):
         default=4,
         help=("The dimension of the LoRA update matrices."),
     )
-
+    parser.add_argument(
+        "--lora_alpha",
+        type=int,
+        default=4,
+        help="LoRA alpha to be used for additional scaling.",
+    )
     parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
-
     parser.add_argument(
         "--with_prior_preservation",
         default=False,
@@ -1023,7 +1028,7 @@ def main(args):
     # now we will add new LoRA weights the transformer layers
     transformer_lora_config = LoraConfig(
         r=args.rank,
-        lora_alpha=args.rank,
+        lora_alpha=args.lora_alpha,
         lora_dropout=args.lora_dropout,
         init_lora_weights="gaussian",
         target_modules=target_modules,
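Note: in PEFT, the LoRA update is scaled by `lora_alpha / r`, so the previous hard-coded `lora_alpha=args.rank` always produced a scale of 1.0; exposing `--lora_alpha` lets the scale be tuned independently of the rank. A minimal sketch of the resulting config (the rank, alpha, and target module names here are illustrative, not taken from the script):

```python
from peft import LoraConfig

# Illustrative values: rank 8 with alpha 16 gives an effective update scale of
# lora_alpha / r = 16 / 8 = 2.0; with the old behaviour (alpha == rank) the
# scale was always 1.0 regardless of the chosen rank.
transformer_lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.0,
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],  # assumed attention projections
)
```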
@@ -1039,10 +1044,11 @@ def unwrap_model(model):
     def save_model_hook(models, weights, output_dir):
         if accelerator.is_main_process:
             transformer_lora_layers_to_save = None
-
+            modules_to_save = {}
             for model in models:
                 if isinstance(model, type(unwrap_model(transformer))):
                     transformer_lora_layers_to_save = get_peft_model_state_dict(model)
+                    modules_to_save["transformer"] = model
                 else:
                     raise ValueError(f"unexpected save model: {model.__class__}")
 
@@ -1052,6 +1058,7 @@ def save_model_hook(models, weights, output_dir):
             SanaPipeline.save_lora_weights(
                 output_dir,
                 transformer_lora_layers=transformer_lora_layers_to_save,
+                **_collate_lora_metadata(modules_to_save),
             )
 
     def load_model_hook(models, input_dir):
@@ -1507,15 +1514,18 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
         transformer = unwrap_model(transformer)
+        modules_to_save = {}
         if args.upcast_before_saving:
             transformer.to(torch.float32)
         else:
             transformer = transformer.to(weight_dtype)
         transformer_lora_layers = get_peft_model_state_dict(transformer)
+        modules_to_save["transformer"] = transformer
 
         SanaPipeline.save_lora_weights(
             save_directory=args.output_dir,
             transformer_lora_layers=transformer_lora_layers,
+            **_collate_lora_metadata(modules_to_save),
         )
 
     # Final inference
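With the metadata collated into `save_lora_weights`, the stored LoRA weights carry their PEFT settings (notably `lora_alpha`), so a later load should not fall back to a default scale. A rough usage sketch, assuming a Sana base checkpoint on the Hub and the output directory produced by this script (model id, paths, and prompt are placeholders):

```python
import torch
from diffusers import SanaPipeline

# Placeholder model id and LoRA path; adjust to your base model and training output.
pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_diffusers", torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights("path/to/output_dir")  # should pick up the saved LoRA metadata

image = pipe(prompt="a photo of sks dog", num_inference_steps=20).images[0]
image.save("sks_dog.png")
```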