@@ -122,7 +122,6 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f
122122
123123        for  _  in  range (args .num_validation_images ):
124124            with  autocast_ctx :
125-                 # need to fix in pipeline_flux_controlnet 
126125                image  =  pipeline (
127126                    prompt = validation_prompt ,
128127                    control_image = validation_image ,
@@ -159,7 +158,7 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f
159158                images  =  log ["images" ]
160159                validation_prompt  =  log ["validation_prompt" ]
161160                validation_image  =  log ["validation_image" ]
162-                 formatted_images .append (wandb .Image (validation_image , caption = "Controlnet conditioning " ))
161+                 formatted_images .append (wandb .Image (validation_image , caption = "Conditioning " ))
163162                for  image  in  images :
164163                    image  =  wandb .Image (image , caption = validation_prompt )
165164                    formatted_images .append (image )
@@ -188,7 +187,7 @@ def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=N
188187            img_str  +=  f"![images_{ i } ](./images_{ i } .png)\n " 
189188
190189    model_description  =  f""" 
191- # control-lora -{ repo_id }  
190+ # flux-control -{ repo_id }  
192191
193192These are Control weights trained on { base_model }  
194193{ img_str } 
@@ -434,14 +433,15 @@ def parse_args(input_args=None):
434433        "--conditioning_image_column" ,
435434        type = str ,
436435        default = "conditioning_image" ,
437-         help = "The column of the dataset containing the controlnet  conditioning image." ,
436+         help = "The column of the dataset containing the control  conditioning image." ,
438437    )
439438    parser .add_argument (
440439        "--caption_column" ,
441440        type = str ,
442441        default = "text" ,
443442        help = "The column of the dataset containing a caption or a list of captions." ,
444443    )
444+     parser .add_argument ("--log_dataset_samples" , action = "store_true" , help = "Whether to log some dataset samples." )
445445    parser .add_argument (
446446        "--max_train_samples" ,
447447        type = int ,
@@ -468,7 +468,7 @@ def parse_args(input_args=None):
468468        default = None ,
469469        nargs = "+" ,
470470        help = (
471-             "A set of paths to the controlnet  conditioning image be evaluated every `--validation_steps`" 
471+             "A set of paths to the control  conditioning image be evaluated every `--validation_steps`" 
472472            " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s," 
473473            " a single `--validation_prompt` to be used with all `--validation_image`s, or a single" 
474474            " `--validation_image` that will be used with all `--validation_prompt`s." 
@@ -505,7 +505,11 @@ def parse_args(input_args=None):
505505        default = None ,
506506        help = "Path to the jsonl file containing the training data." ,
507507    )
508- 
508+     parser .add_argument (
509+         "--only_target_transformer_blocks" ,
510+         action = "store_true" ,
511+         help = "If we should only target the transformer blocks to train along with the input layer (`x_embedder`)." ,
512+     )
509513    parser .add_argument (
510514        "--guidance_scale" ,
511515        type = float ,
@@ -581,7 +585,7 @@ def parse_args(input_args=None):
581585
582586    if  args .resolution  %  8  !=  0 :
583587        raise  ValueError (
584-             "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder ." 
588+             "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the Flux transformer ." 
585589        )
586590
587591    return  args 
@@ -665,7 +669,12 @@ def preprocess_train(examples):
665669        conditioning_images  =  [image_transforms (image ) for  image  in  conditioning_images ]
666670        examples ["pixel_values" ] =  images 
667671        examples ["conditioning_pixel_values" ] =  conditioning_images 
668-         examples ["captions" ] =  list (examples [args .caption_column ])
672+ 
673+         is_caption_list  =  isinstance (examples [args .caption_column ][0 ], list )
674+         if  is_caption_list :
675+             examples ["captions" ] =  [max (example , key = len ) for  example  in  examples [args .caption_column ]]
676+         else :
677+             examples ["captions" ] =  list (examples [args .caption_column ])
669678
670679        return  examples 
671680
@@ -765,7 +774,8 @@ def main(args):
765774        subfolder = "scheduler" ,
766775    )
767776    noise_scheduler_copy  =  copy .deepcopy (noise_scheduler )
768-     flux_transformer .requires_grad_ (True )
777+     if  not  args .only_target_transformer_blocks :
778+         flux_transformer .requires_grad_ (True )
769779    vae .requires_grad_ (False )
770780
771781    # cast down and move to the CPU 
@@ -797,6 +807,12 @@ def main(args):
797807    assert  torch .all (flux_transformer .x_embedder .weight [:, initial_input_channels :].data  ==  0 )
798808    flux_transformer .register_to_config (in_channels = initial_input_channels  *  2 , out_channels = initial_input_channels )
799809
810+     if  args .only_target_transformer_blocks :
811+         flux_transformer .x_embedder .requires_grad_ (True )
812+         for  name , module  in  flux_transformer .named_modules ():
813+             if  "transformer_blocks"  in  name :
814+                 module .requires_grad_ (True )
815+ 
800816    def  unwrap_model (model ):
801817        model  =  accelerator .unwrap_model (model )
802818        model  =  model ._orig_mod  if  is_compiled_module (model ) else  model 
@@ -974,6 +990,32 @@ def load_model_hook(models, input_dir):
974990    else :
975991        initial_global_step  =  0 
976992
993+     if  accelerator .is_main_process  and  args .report_to  ==  "wandb"  and  args .log_dataset_samples :
994+         logger .info ("Logging some dataset samples." )
995+         formatted_images  =  []
996+         formatted_control_images  =  []
997+         all_prompts  =  []
998+         for  i , batch  in  enumerate (train_dataloader ):
999+             images  =  (batch ["pixel_values" ] +  1 ) /  2 
1000+             control_images  =  (batch ["conditioning_pixel_values" ] +  1 ) /  2 
1001+             prompts  =  batch ["captions" ]
1002+ 
1003+             if  len (formatted_images ) >  10 :
1004+                 break 
1005+ 
1006+             for  img , control_img , prompt  in  zip (images , control_images , prompts ):
1007+                 formatted_images .append (img )
1008+                 formatted_control_images .append (control_img )
1009+                 all_prompts .append (prompt )
1010+ 
1011+         logged_artifacts  =  []
1012+         for  img , control_img , prompt  in  zip (formatted_images , formatted_control_images , all_prompts ):
1013+             logged_artifacts .append (wandb .Image (control_img , caption = "Conditioning" ))
1014+             logged_artifacts .append (wandb .Image (img , caption = prompt ))
1015+ 
1016+         wandb_tracker  =  [tracker  for  tracker  in  accelerator .trackers  if  tracker .name  ==  "wandb" ]
1017+         wandb_tracker [0 ].log ({"dataset_samples" : logged_artifacts })
1018+ 
9771019    progress_bar  =  tqdm (
9781020        range (0 , args .max_train_steps ),
9791021        initial = initial_global_step ,
0 commit comments