1414# See the License for the specific language governing permissions and 
1515# limitations under the License. 
1616
"""
    Script to fine-tune Stable Diffusion for LoRA InstructPix2Pix.
    Base code adapted from: https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py
"""
1821
1922import  argparse 
2023import  logging 
3033import  PIL 
3134import  requests 
3235import  torch 
36+ import  torch .nn  as  nn 
3337import  torch .nn .functional  as  F 
3438import  torch .utils .checkpoint 
3539import  transformers 
5054from  diffusers .training_utils  import  EMAModel 
5155from  diffusers .utils  import  check_min_version , deprecate , is_wandb_available 
5256from  diffusers .utils .import_utils  import  is_xformers_available 
57+ if  is_wandb_available ():
58+ 
59+     import  wandb 
5360
5461
5562# Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
56- check_min_version ("0.26 .0.dev0" )
63+ check_min_version ("0.32 .0.dev0" )
5764
5865logger  =  get_logger (__name__ , log_level = "INFO" )
5966
6269}
6370WANDB_TABLE_COL_NAMES  =  ["original_image" , "edited_image" , "edit_prompt" ]
6471
def log_validation(
    pipeline,
    args,
    accelerator,
    generator,
):
    """Run validation inference and log the edited images to active trackers.

    Downloads the validation image from ``args.val_image_url``, generates
    ``args.num_validation_images`` edited images conditioned on
    ``args.validation_prompt``, and logs (original, edited, prompt) rows as a
    table to any active ``wandb`` tracker.

    Args:
        pipeline: An InstructPix2Pix-style diffusion pipeline to run inference with.
        args: Parsed training arguments; must provide ``num_validation_images``,
            ``validation_prompt`` and ``val_image_url``.
        accelerator: The ``accelerate.Accelerator`` driving training (supplies the
            device and the configured trackers).
        generator: A seeded ``torch.Generator`` for reproducible sampling.

    Returns:
        list: The generated (edited) PIL images.
    """
    logger.info(
        f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
        f" {args.validation_prompt}."
    )
    pipeline = pipeline.to(accelerator.device)
    pipeline.set_progress_bar_config(disable=True)

    # run inference
    original_image = download_image(args.val_image_url)
    edited_images = []
    # torch.autocast is not supported on MPS; fall back to a no-op context there.
    if torch.backends.mps.is_available():
        autocast_ctx = nullcontext()
    else:
        autocast_ctx = torch.autocast(accelerator.device.type)

    with autocast_ctx:
        for _ in range(args.num_validation_images):
            edited_images.append(
                pipeline(
                    args.validation_prompt,
                    image=original_image,
                    num_inference_steps=20,
                    image_guidance_scale=1.5,
                    guidance_scale=7,
                    generator=generator,
                ).images[0]
            )

    # Only the wandb tracker knows how to log an image table; other trackers are skipped.
    for tracker in accelerator.trackers:
        if tracker.name == "wandb":
            wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
            for edited_image in edited_images:
                wandb_table.add_data(wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt)
            tracker.log({"validation": wandb_table})

    return edited_images
65114
66115def  parse_args ():
67116    parser  =  argparse .ArgumentParser (description = "Simple example of a training script for InstructPix2Pix." )
@@ -417,11 +466,6 @@ def main():
417466
418467    generator  =  torch .Generator (device = accelerator .device ).manual_seed (args .seed )
419468
420-     if  args .report_to  ==  "wandb" :
421-         if  not  is_wandb_available ():
422-             raise  ImportError ("Make sure to install wandb if you want to use it for logging during training." )
423-         import  wandb 
424- 
425469    # Make one log on every process with the configuration for debugging. 
426470    logging .basicConfig (
427471        format = "%(asctime)s - %(levelname)s - %(name)s - %(message)s" ,
@@ -467,6 +511,24 @@ def main():
467511        args .pretrained_model_name_or_path , subfolder = "unet" , revision = args .non_ema_revision 
468512    )
469513
514+     # InstructPix2Pix uses an additional image for conditioning. To accommodate that, 
515+     # it uses 8 channels (instead of 4) in the first (conv) layer of the UNet. This UNet is 
516+     # then fine-tuned on the custom InstructPix2Pix dataset. This modified UNet is initialized 
517+     # from the pre-trained checkpoints. For the extra channels added to the first layer, they are 
518+     # initialized to zero. 
519+     logger .info ("Initializing the InstructPix2Pix UNet from the pretrained UNet." )
520+     in_channels  =  8 
521+     out_channels  =  unet .conv_in .out_channels 
522+     unet .register_to_config (in_channels = in_channels )
523+ 
524+     with  torch .no_grad ():
525+         new_conv_in  =  nn .Conv2d (
526+             in_channels , out_channels , unet .conv_in .kernel_size , unet .conv_in .stride , unet .conv_in .padding 
527+         )
528+         new_conv_in .weight .zero_ ()
529+         new_conv_in .weight [:, :in_channels , :, :].copy_ (unet .conv_in .weight )
530+         unet .conv_in  =  new_conv_in 
531+ 
470532    # Freeze vae, text_encoder and unet 
471533    vae .requires_grad_ (False )
472534    text_encoder .requires_grad_ (False )
@@ -528,6 +590,11 @@ def main():
528590        else :
529591            raise  ValueError ("xformers is not available. Make sure it is installed correctly" )
530592
    def unwrap_model(model):
        # Strip the accelerate wrapper (e.g. DDP) from the model and, if the
        # model was torch.compile'd, return the original eager-mode module
        # so it can be saved / passed to a pipeline directly.
        model = accelerator.unwrap_model(model)
        model = model._orig_mod if is_compiled_module(model) else model
        return model
597+ 
531598    # `accelerate` 0.16.0 will have better support for customized saving 
532599    if  version .parse (accelerate .__version__ ) >=  version .parse ("0.16.0" ):
533600        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format 
@@ -540,7 +607,8 @@ def save_model_hook(models, weights, output_dir):
540607                    model .save_pretrained (os .path .join (output_dir , "unet" ))
541608
542609                    # make sure to pop weight so that corresponding model is not saved again 
543-                     weights .pop ()
610+                     if  weights :
611+                         weights .pop ()
544612
545613        def  load_model_hook (models , input_dir ):
546614            if  args .use_ema :
@@ -730,17 +798,22 @@ def collate_fn(examples):
730798    )
731799
732800    # Scheduler and math around the number of training steps. 
733-     overrode_max_train_steps   =   False 
734-     num_update_steps_per_epoch  =  math . ceil ( len ( train_dataloader )  /   args . gradient_accumulation_steps ) 
801+     # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation. 
802+     num_warmup_steps_for_scheduler  =  args . lr_warmup_steps   *   accelerator . num_processes 
735803    if  args .max_train_steps  is  None :
736-         args .max_train_steps  =  args .num_train_epochs  *  num_update_steps_per_epoch 
737-         overrode_max_train_steps  =  True 
804+         len_train_dataloader_after_sharding  =  math .ceil (len (train_dataloader ) /  accelerator .num_processes )
805+         num_update_steps_per_epoch  =  math .ceil (len_train_dataloader_after_sharding  /  args .gradient_accumulation_steps )
806+         num_training_steps_for_scheduler  =  (
807+             args .num_train_epochs  *  num_update_steps_per_epoch  *  accelerator .num_processes 
808+         )
809+     else :
810+         num_training_steps_for_scheduler  =  args .max_train_steps  *  accelerator .num_processes 
738811
739812    lr_scheduler  =  get_scheduler (
740813        args .lr_scheduler ,
741814        optimizer = optimizer ,
742-         num_warmup_steps = args . lr_warmup_steps   *   accelerator . num_processes ,
743-         num_training_steps = args . max_train_steps   *   accelerator . num_processes ,
815+         num_warmup_steps = num_warmup_steps_for_scheduler ,
816+         num_training_steps = num_training_steps_for_scheduler ,
744817    )
745818
746819    # Prepare everything with our `accelerator`. 
@@ -765,8 +838,14 @@ def collate_fn(examples):
765838
766839    # We need to recalculate our total training steps as the size of the training dataloader may have changed. 
767840    num_update_steps_per_epoch  =  math .ceil (len (train_dataloader ) /  args .gradient_accumulation_steps )
768-     if  overrode_max_train_steps :
841+     if  args . max_train_steps   is   None :
769842        args .max_train_steps  =  args .num_train_epochs  *  num_update_steps_per_epoch 
843+         if  num_training_steps_for_scheduler  !=  args .max_train_steps  *  accelerator .num_processes :
844+             logger .warning (
845+                 f"The length of the 'train_dataloader' after 'accelerator.prepare' ({ len (train_dataloader )}  
846+                 f"the expected length ({ len_train_dataloader_after_sharding }  
847+                 f"This inconsistency may result in the learning rate scheduler not functioning properly." 
848+             )
770849    # Afterwards we recalculate our number of training epochs 
771850    args .num_train_epochs  =  math .ceil (args .max_train_steps  /  num_update_steps_per_epoch )
772851
@@ -959,45 +1038,22 @@ def collate_fn(examples):
9591038                # The models need unwrapping because for compatibility in distributed training mode. 
9601039                pipeline  =  StableDiffusionInstructPix2PixPipeline .from_pretrained (
9611040                    args .pretrained_model_name_or_path ,
962-                     unet = accelerator . unwrap_model (unet ),
963-                     text_encoder = accelerator . unwrap_model (text_encoder ),
964-                     vae = accelerator . unwrap_model (vae ),
1041+                     unet = unwrap_model (unet ),
1042+                     text_encoder = unwrap_model (text_encoder ),
1043+                     vae = unwrap_model (vae ),
9651044                    revision = args .revision ,
9661045                    variant = args .variant ,
9671046                    torch_dtype = weight_dtype ,
9681047                )
969-                 pipeline  =  pipeline .to (accelerator .device )
970-                 pipeline .set_progress_bar_config (disable = True )
9711048
9721049                # run inference 
973-                 original_image  =  download_image (args .val_image_url )
974-                 edited_images  =  []
975-                 if  torch .backends .mps .is_available ():
976-                     autocast_ctx  =  nullcontext ()
977-                 else :
978-                     autocast_ctx  =  torch .autocast (accelerator .device .type )
979- 
980-                 with  autocast_ctx :
981-                     for  _  in  range (args .num_validation_images ):
982-                         edited_images .append (
983-                             pipeline (
984-                                 args .validation_prompt ,
985-                                 image = original_image ,
986-                                 num_inference_steps = 20 ,
987-                                 image_guidance_scale = 1.5 ,
988-                                 guidance_scale = 7 ,
989-                                 generator = generator ,
990-                             ).images [0 ]
991-                         )
992- 
993-                 for  tracker  in  accelerator .trackers :
994-                     if  tracker .name  ==  "wandb" :
995-                         wandb_table  =  wandb .Table (columns = WANDB_TABLE_COL_NAMES )
996-                         for  edited_image  in  edited_images :
997-                             wandb_table .add_data (
998-                                 wandb .Image (original_image ), wandb .Image (edited_image ), args .validation_prompt 
999-                             )
1000-                         tracker .log ({"validation" : wandb_table })
1050+                 log_validation (
1051+                     pipeline ,
1052+                     args ,
1053+                     accelerator ,
1054+                     generator ,
1055+                 )
1056+ 
10011057                if  args .use_ema :
10021058                    # Switch back to the original UNet parameters. 
10031059                    ema_unet .restore (unet .parameters ())
@@ -1014,9 +1070,9 @@ def collate_fn(examples):
10141070
10151071        pipeline  =  StableDiffusionInstructPix2PixPipeline .from_pretrained (
10161072            args .pretrained_model_name_or_path ,
1017-             text_encoder = accelerator . unwrap_model (text_encoder ),
1018-             vae = accelerator . unwrap_model (vae ),
1019-             unet = unet ,
1073+             text_encoder = unwrap_model (text_encoder ),
1074+             vae = unwrap_model (vae ),
1075+             unet = unwrap_model ( unet ) ,
10201076            revision = args .revision ,
10211077            variant = args .variant ,
10221078        )
@@ -1031,31 +1087,6 @@ def collate_fn(examples):
10311087                ignore_patterns = ["step_*" , "epoch_*" ],
10321088            )
10331089
1034-         if  args .validation_prompt  is  not None :
1035-             edited_images  =  []
1036-             pipeline  =  pipeline .to (accelerator .device )
1037-             with  torch .autocast (str (accelerator .device ).replace (":0" , "" )):
1038-                 for  _  in  range (args .num_validation_images ):
1039-                     edited_images .append (
1040-                         pipeline (
1041-                             args .validation_prompt ,
1042-                             image = original_image ,
1043-                             num_inference_steps = 20 ,
1044-                             image_guidance_scale = 1.5 ,
1045-                             guidance_scale = 7 ,
1046-                             generator = generator ,
1047-                         ).images [0 ]
1048-                     )
1049- 
1050-             for  tracker  in  accelerator .trackers :
1051-                 if  tracker .name  ==  "wandb" :
1052-                     wandb_table  =  wandb .Table (columns = WANDB_TABLE_COL_NAMES )
1053-                     for  edited_image  in  edited_images :
1054-                         wandb_table .add_data (
1055-                             wandb .Image (original_image ), wandb .Image (edited_image ), args .validation_prompt 
1056-                         )
1057-                     tracker .log ({"test" : wandb_table })
1058- 
10591090    accelerator .end_training ()
10601091
10611092
0 commit comments