 import logging
 import math
 import os
+import random
 import shutil
 from contextlib import nullcontext
 from pathlib import Path
@@ -76,13 +77,16 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f
         pipeline = FluxControlPipeline.from_pretrained(
             args.pretrained_model_name_or_path,
             transformer=flux_transformer,
-            torch_dtype=torch.bfloat16,
+            torch_dtype=weight_dtype,
         )
     else:
+        transformer = FluxTransformer2DModel.from_pretrained(
+            args.pretrained_model_name_or_path, subfolder="transformer", torch_dtype=weight_dtype
+        )
         pipeline = FluxControlPipeline.from_pretrained(
             args.pretrained_model_name_or_path,
-            transformer=flux_transformer,
-            torch_dtype=torch.bfloat16,
+            transformer=transformer,
+            torch_dtype=weight_dtype,
         )
         pipeline.load_lora_weights(args.output_dir)

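After this change, the final-validation branch loads a fresh transformer in `weight_dtype` and attaches the trained LoRA from `args.output_dir` via `load_lora_weights`. For reference, a minimal sketch of how such a control LoRA could be used at inference time; the model id, LoRA directory, image path, and prompt below are placeholders, not values from this diff:

```python
import torch
from diffusers import FluxControlPipeline
from diffusers.utils import load_image

# Hypothetical paths and prompt, for illustration only.
pipe = FluxControlPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights("path/to/output_dir")  # directory written by the training script

control_image = load_image("path/to/control.png")
image = pipe(
    prompt="a scenic mountain landscape",
    control_image=control_image,
    num_inference_steps=28,
    guidance_scale=3.5,
).images[0]
image.save("output.png")
```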
@@ -307,6 +311,12 @@ def parse_args(input_args=None):
         default=4,
         help=("The dimension of the LoRA update matrices."),
     )
+    parser.add_argument(
+        "--proportion_empty_prompts",
+        type=float,
+        default=0,
+        help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
+    )
     parser.add_argument(
         "--lora_layers",
         type=str,
@@ -474,12 +484,6 @@ def parse_args(input_args=None):
             "value if set."
         ),
     )
-    parser.add_argument(
-        "--proportion_empty_prompts",
-        type=float,
-        default=0,
-        help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
-    )
     parser.add_argument(
         "--validation_prompt",
         type=str,
@@ -864,13 +868,15 @@ def save_model_hook(models, weights, output_dir):
                 transformer_lora_layers_to_save = None

                 for model in models:
-                    if isinstance(model, type(unwrap_model(flux_transformer))):
+                    if isinstance(unwrap_model(model), type(unwrap_model(flux_transformer))):
+                        model = unwrap_model(model)
                         transformer_lora_layers_to_save = get_peft_model_state_dict(model)
                     else:
                         raise ValueError(f"unexpected save model: {model.__class__}")

                     # make sure to pop weight so that corresponding model is not saved again
-                    weights.pop()
+                    if weights:
+                        weights.pop()

                 FluxControlPipeline.save_lora_weights(
                     output_dir,
@@ -880,16 +886,22 @@ def save_model_hook(models, weights, output_dir):
         def load_model_hook(models, input_dir):
             transformer_ = None

-            while len(models) > 0:
-                model = models.pop()
+            if not accelerator.distributed_type == DistributedType.DEEPSPEED:
+                while len(models) > 0:
+                    model = models.pop()

-                if isinstance(model, type(unwrap_model(flux_transformer))):
-                    transformer_ = model
-                else:
-                    raise ValueError(f"unexpected save model: {model.__class__}")
+                    if isinstance(model, type(unwrap_model(flux_transformer))):
+                        transformer_ = model
+                    else:
+                        raise ValueError(f"unexpected save model: {model.__class__}")

-            lora_state_dict = FluxControlPipeline.lora_state_dict(input_dir)
+            else:
+                transformer_ = FluxTransformer2DModel.from_pretrained(
+                    args.pretrained_model_name_or_path, subfolder="transformer"
+                ).to(accelerator.device, weight_dtype)
+                transformer_.add_adapter(transformer_lora_config)

+            lora_state_dict = FluxControlPipeline.lora_state_dict(input_dir)
             transformer_state_dict = {
                 f'{k.replace("transformer.", "")}': v
                 for k, v in lora_state_dict.items()
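In the DeepSpeed branch above, the hook does not pop the wrapped model from `models`; it rebuilds the transformer from the pretrained checkpoint and re-attaches a LoRA adapter so the saved LoRA state dict loaded just below has matching parameter names. A minimal, hypothetical sketch of that pattern (the helper name, rank, and target modules are illustrative, not taken from this diff; in the script the config comes from `--rank` / `--lora_layers`):

```python
from diffusers import FluxTransformer2DModel
from peft import LoraConfig


def rebuild_lora_transformer(pretrained_path: str, rank: int = 4):
    # Load the base transformer weights again from the pretrained checkpoint.
    transformer = FluxTransformer2DModel.from_pretrained(
        pretrained_path, subfolder="transformer"
    )
    # Attach a fresh LoRA adapter so the restored LoRA state dict can be
    # loaded into parameters with the expected names.
    lora_config = LoraConfig(
        r=rank,
        lora_alpha=rank,
        init_lora_weights="gaussian",
        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
    )
    transformer.add_adapter(lora_config)
    return transformer
```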
@@ -1135,7 +1147,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 )

                 # handle guidance
-                if flux_transformer.config.guidance_embeds:
+                if unwrap_model(flux_transformer).config.guidance_embeds:
                     guidance_vec = torch.full(
                         (bsz,),
                         args.guidance_scale,
@@ -1152,7 +1164,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                     prompt_embeds, pooled_prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt(
                         captions, prompt_2=None
                     )
-                text_encoding_pipeline = text_encoding_pipeline.to("cuda")
+                # this could be optimized by not having to do any text encoding and just
+                # doing zeros on specified shapes for `prompt_embeds` and `pooled_prompt_embeds`
+                if args.proportion_empty_prompts and random.random() < args.proportion_empty_prompts:
+                    prompt_embeds.zero_()
+                    pooled_prompt_embeds.zero_()
+                text_encoding_pipeline = text_encoding_pipeline.to("cpu")

                 # Predict.
                 model_pred = flux_transformer(
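The new `--proportion_empty_prompts` handling above is a simple caption dropout: with the configured probability, the already computed text embeddings are zeroed in place, which, as the diff's own comment notes, stands in for encoding an empty prompt. A standalone sketch of the same pattern (the function name and default probability are illustrative):

```python
import random

import torch


def maybe_drop_prompt(prompt_embeds: torch.Tensor,
                      pooled_prompt_embeds: torch.Tensor,
                      drop_prob: float = 0.1) -> None:
    # With probability `drop_prob`, zero the prompt embeddings in place,
    # approximating training on an empty prompt.
    if drop_prob and random.random() < drop_prob:
        prompt_embeds.zero_()
        pooled_prompt_embeds.zero_()
```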
@@ -1274,7 +1291,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 repo_id=repo_id,
                 folder_path=args.output_dir,
                 commit_message="End of training",
-                ignore_patterns=["step_*", "epoch_*"],
+                ignore_patterns=["step_*", "epoch_*", "*.pt", "*.bin"],
             )

     accelerator.end_training()