 from datasets import load_dataset
 from huggingface_hub import create_repo, upload_folder
 from packaging import version
+from peft import LoraConfig
+from peft.utils import get_peft_model_state_dict
 from torchvision import transforms
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
 from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionInstructPix2PixPipeline, UNet2DConditionModel
 from diffusers.models.lora import LoRALinearLayer
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import EMAModel
-from diffusers.utils import check_min_version, deprecate, is_wandb_available
+from diffusers.training_utils import cast_training_params, EMAModel
+from diffusers.utils import check_min_version, deprecate, convert_state_dict_to_diffusers, is_wandb_available
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
 from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module
+
 if is_wandb_available():
     import wandb
 }
 WANDB_TABLE_COL_NAMES = ["original_image", "edited_image", "edit_prompt"]
 
+def save_model_card(
+    repo_id: str,
+    images: list = None,
+    base_model: str = None,
+    dataset_name: str = None,
+    repo_folder: str = None,
+):
+    img_str = ""
+    if images is not None:
+        for i, image in enumerate(images):
+            image.save(os.path.join(repo_folder, f"image_{i}.png"))
+            img_str += f"![img_{i}](./image_{i}.png)\n"
+
+    model_description = f"""
+# LoRA text2image fine-tuning - {repo_id}
+These are LoRA adaption weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images below. \n
+{img_str}
+"""
+
+    model_card = load_or_create_model_card(
+        repo_id_or_path=repo_id,
+        from_training=True,
+        license="creativeml-openrail-m",
+        base_model=base_model,
+        model_description=model_description,
+        inference=True,
+    )
+
+    tags = [
+        "stable-diffusion",
+        "stable-diffusion-diffusers",
+        "text-to-image",
+        "diffusers",
+        "diffusers-training",
+        "lora",
+    ]
+    model_card = populate_model_card(model_card, tags=tags)
+
+    model_card.save(os.path.join(repo_folder, "README.md"))
+
+
 def log_validation(
     pipeline,
     args,
@@ -535,43 +581,35 @@ def main():
     unet.requires_grad_(False)

     # referred to https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py
-    unet_lora_parameters = []
-    for attn_processor_name, attn_processor in unet.attn_processors.items():
-        # Parse the attention module.
-        attn_module = unet
-        for n in attn_processor_name.split(".")[:-1]:
-            attn_module = getattr(attn_module, n)
-
-        # Set the `lora_layer` attribute of the attention-related matrices.
-        attn_module.to_q.set_lora_layer(
-            LoRALinearLayer(
-                in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
-            )
-        )
-        attn_module.to_k.set_lora_layer(
-            LoRALinearLayer(
-                in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
-            )
-        )
+    # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
+    # as these weights are only used for inference, keeping weights in full precision is not required.
+    weight_dtype = torch.float32
+    if accelerator.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif accelerator.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16

-        attn_module.to_v.set_lora_layer(
-            LoRALinearLayer(
-                in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
-            )
-        )
-        attn_module.to_out[0].set_lora_layer(
-            LoRALinearLayer(
-                in_features=attn_module.to_out[0].in_features,
-                out_features=attn_module.to_out[0].out_features,
-                rank=args.rank,
-            )
-        )
+    # Freeze the unet parameters before adding adapters
+    for param in unet.parameters():
+        param.requires_grad_(False)

-        # Accumulate the LoRA params to optimize.
-        unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
-        unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
-        unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
-        unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
+    unet_lora_config = LoraConfig(
+        r=args.rank,
+        lora_alpha=args.rank,
+        init_lora_weights="gaussian",
+        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
+    )
+
+    # Move unet, vae and text_encoder to device and cast to weight_dtype
+    unet.to(accelerator.device, dtype=weight_dtype)
+    vae.to(accelerator.device, dtype=weight_dtype)
+    text_encoder.to(accelerator.device, dtype=weight_dtype)
+
+    # Add adapter and make sure the trainable params are in float32.
+    unet.add_adapter(unet_lora_config)
+    if args.mixed_precision == "fp16":
+        # only upcast trainable parameters (LoRA) into fp32
+        cast_training_params(unet, dtype=torch.float32)

     # Create EMA for the unet.
     if args.use_ema:
@@ -590,6 +628,8 @@ def main():
         else:
             raise ValueError("xformers is not available. Make sure it is installed correctly")

+    lora_layers = filter(lambda p: p.requires_grad, unet.parameters())
+
     def unwrap_model(model):
         model = accelerator.unwrap_model(model)
         model = model._orig_mod if is_compiled_module(model) else model
@@ -657,9 +697,9 @@ def load_model_hook(models, input_dir):
     else:
         optimizer_cls = torch.optim.AdamW

-    # train on only unet_lora_parameters
+    # train on only lora_layers
     optimizer = optimizer_cls(
-        unet_lora_parameters,
+        lora_layers,
         lr=args.learning_rate,
         betas=(args.adam_beta1, args.adam_beta2),
         weight_decay=args.adam_weight_decay,
@@ -817,8 +857,8 @@ def collate_fn(examples):
     )

     # Prepare everything with our `accelerator`.
-    unet, unet_lora_parameters, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-        unet, unet_lora_parameters, optimizer, train_dataloader, lr_scheduler
+    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+        unet, optimizer, train_dataloader, lr_scheduler
     )

     if args.use_ema:
@@ -964,7 +1004,7 @@ def collate_fn(examples):
                     raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

                 # Predict the noise residual and compute loss
-                model_pred = unet(concatenated_noisy_latents, timesteps, encoder_hidden_states).sample
+                model_pred = unet(concatenated_noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]
                 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

                 # Gather the losses across all processes for logging (if we use distributed training).
@@ -974,15 +1014,15 @@ def collate_fn(examples):
                 # Backpropagate
                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
-                    accelerator.clip_grad_norm_(unet_lora_parameters, args.max_grad_norm)
+                    accelerator.clip_grad_norm_(lora_layers, args.max_grad_norm)
                 optimizer.step()
                 lr_scheduler.step()
                 optimizer.zero_grad()

             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 if args.use_ema:
-                    ema_unet.step(unet_lora_parameters)
+                    ema_unet.step(lora_layers)
                 progress_bar.update(1)
                 global_step += 1
                 accelerator.log({"train_loss": train_loss}, step=global_step)
@@ -1012,6 +1052,16 @@ def collate_fn(examples):
 
                         save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                         accelerator.save_state(save_path)
+                        unwrapped_unet = unwrap_model(unet)
+                        unet_lora_state_dict = convert_state_dict_to_diffusers(
+                            get_peft_model_state_dict(unwrapped_unet)
+                        )
+
+                        StableDiffusionInstructPix2PixPipeline.save_lora_weights(
+                            save_directory=save_path,
+                            unet_lora_layers=unet_lora_state_dict,
+                            safe_serialization=True,
+                        )
                         logger.info(f"Saved state to {save_path}")

             logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
@@ -1064,10 +1114,20 @@ def collate_fn(examples):
     # Create the pipeline using the trained modules and save it.
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
-        unet = accelerator.unwrap_model(unet)
         if args.use_ema:
             ema_unet.copy_to(unet.parameters())

+        # store only LORA layers
+        unet = unet.to(torch.float32)
+
+        unwrapped_unet = unwrap_model(unet)
+        unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unwrapped_unet))
+        StableDiffusionInstructPix2PixPipeline.save_lora_weights(
+            save_directory=args.output_dir,
+            unet_lora_layers=unet_lora_state_dict,
+            safe_serialization=True,
+        )
+
         pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
             args.pretrained_model_name_or_path,
             text_encoder=unwrap_model(text_encoder),
@@ -1076,10 +1136,25 @@ def collate_fn(examples):
             revision=args.revision,
             variant=args.variant,
         )
-        # store only LORA layers
-        unet.save_attn_procs(args.output_dir)
+        pipeline.load_lora_weights(args.output_dir)
+
+        images = None
+        if (args.val_image_url is not None) and (args.validation_prompt is not None):
+            images = log_validation(
+                pipeline,
+                args,
+                accelerator,
+                generator,
+            )

         if args.push_to_hub:
+            save_model_card(
+                repo_id,
+                images=images,
+                base_model=args.pretrained_model_name_or_path,
+                dataset_name=args.dataset_name,
+                repo_folder=args.output_dir,
+            )
             upload_folder(
                 repo_id=repo_id,
                 folder_path=args.output_dir,
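Note (not part of the diff above): a minimal inference sketch for the LoRA weights this change saves via StableDiffusionInstructPix2PixPipeline.save_lora_weights. The base model id, output directory, image URL, edit prompt, and guidance values are placeholders chosen for illustration, not taken from the commit.

import torch

from diffusers import StableDiffusionInstructPix2PixPipeline
from diffusers.utils import load_image

# Assumed base model; use the same checkpoint passed as --pretrained_model_name_or_path during training.
pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
    "timbrooks/instruct-pix2pix", torch_dtype=torch.float16
).to("cuda")

# args.output_dir from the training run contains pytorch_lora_weights.safetensors
# written by save_lora_weights above.
pipeline.load_lora_weights("path/to/output_dir")

original = load_image("https://example.com/original.png")  # placeholder input image
edited = pipeline(
    "make the sky a sunset",  # example edit prompt
    image=original,
    num_inference_steps=20,
    image_guidance_scale=1.5,
    guidance_scale=7.0,
).images[0]
edited.save("edited.png")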