@@ -590,8 +590,7 @@ def main():
590590 weight_dtype = torch .bfloat16
591591
592592 # Freeze the unet parameters before adding adapters
593- for param in unet .parameters ():
594- param .requires_grad_ (False )
593+ unet .requires_grad_ (False )
595594
596595 unet_lora_config = LoraConfig (
597596 r = args .rank ,
@@ -628,7 +627,7 @@ def main():
628627 else :
629628 raise ValueError ("xformers is not available. Make sure it is installed correctly" )
630629
631- lora_layers = filter (lambda p : p .requires_grad , unet .parameters ())
630+ trainable_params = filter (lambda p : p .requires_grad , unet .parameters ())
632631
633632 def unwrap_model (model ):
634633 model = accelerator .unwrap_model (model )
@@ -699,7 +698,7 @@ def load_model_hook(models, input_dir):
699698
700699 # train on only lora_layers
701700 optimizer = optimizer_cls (
702- lora_layers ,
701+ trainable_params ,
703702 lr = args .learning_rate ,
704703 betas = (args .adam_beta1 , args .adam_beta2 ),
705704 weight_decay = args .adam_weight_decay ,
@@ -1014,15 +1013,15 @@ def collate_fn(examples):
10141013 # Backpropagate
10151014 accelerator .backward (loss )
10161015 if accelerator .sync_gradients :
1017- accelerator .clip_grad_norm_ (lora_layers , args .max_grad_norm )
1016+ accelerator .clip_grad_norm_ (trainable_params , args .max_grad_norm )
10181017 optimizer .step ()
10191018 lr_scheduler .step ()
10201019 optimizer .zero_grad ()
10211020
10221021 # Checks if the accelerator has performed an optimization step behind the scenes
10231022 if accelerator .sync_gradients :
10241023 if args .use_ema :
1025- ema_unet .step (lora_layers )
1024+ ema_unet .step (trainable_params )
10261025 progress_bar .update (1 )
10271026 global_step += 1
10281027 accelerator .log ({"train_loss" : train_loss }, step = global_step )
0 commit comments