2626import  accelerate 
2727import  numpy  as  np 
2828import  torch 
29- import  torch .utils .checkpoint 
3029import  transformers 
3130from  accelerate  import  Accelerator 
3231from  accelerate .logging  import  get_logger 
4948    compute_loss_weighting_for_sd3 ,
5049    free_memory ,
5150)
52- from  diffusers .utils  import  check_min_version , is_wandb_available , make_image_grid 
51+ from  diffusers .utils  import  check_min_version , is_wandb_available , load_image ,  make_image_grid 
5352from  diffusers .utils .hub_utils  import  load_or_create_model_card , populate_model_card 
5453from  diffusers .utils .torch_utils  import  is_compiled_module 
5554
6362logger  =  get_logger (__name__ )
6463
6564
def encode_images(pixels: torch.Tensor, vae: torch.nn.Module, weight_dtype) -> torch.Tensor:
    """Encode a batch of images into scaled VAE latents.

    Args:
        pixels: Image batch tensor; cast to the VAE's dtype before encoding.
        vae: Autoencoder whose ``encode`` returns an object exposing
            ``latent_dist.sample()`` and whose ``config`` provides
            ``shift_factor`` and ``scaling_factor``.
        weight_dtype: Target dtype for the returned latents.

    Returns:
        Latents shifted by ``config.shift_factor`` and multiplied by
        ``config.scaling_factor``, cast to ``weight_dtype``.
    """
    pixel_latents = vae.encode(pixels.to(vae.dtype)).latent_dist.sample()
    # Normalize latents the way the Flux pipeline expects: shift then scale.
    pixel_latents = (pixel_latents - vae.config.shift_factor) * vae.config.scaling_factor
    return pixel_latents.to(weight_dtype)
7069
7170
7271def  log_validation (flux_transformer , args , accelerator , weight_dtype , step , is_final_validation = False ):
7372    logger .info ("Running validation... " )
74-     flux_transformer  =  accelerator .unwrap_model (flux_transformer )
7573
7674    if  not  is_final_validation :
75+         flux_transformer  =  accelerator .unwrap_model (flux_transformer )
7776        pipeline  =  FluxControlPipeline .from_pretrained (
7877            args .pretrained_model_name_or_path ,
7978            transformer = flux_transformer ,
@@ -83,12 +82,16 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f
8382        transformer  =  FluxTransformer2DModel .from_pretrained (
8483            args .pretrained_model_name_or_path , subfolder = "transformer" , torch_dtype = weight_dtype 
8584        )
85+         initial_channels  =  transformer .config .in_channels 
8686        pipeline  =  FluxControlPipeline .from_pretrained (
8787            args .pretrained_model_name_or_path ,
8888            transformer = transformer ,
8989            torch_dtype = weight_dtype ,
9090        )
9191        pipeline .load_lora_weights (args .output_dir )
92+         assert  (
93+             pipeline .transformer .config .in_channels  ==  initial_channels  *  2 
94+         ), f"{ pipeline .transformer .config .in_channels = } "
9295
9396    pipeline .to (accelerator .device )
9497    pipeline .set_progress_bar_config (disable = True )
@@ -119,8 +122,6 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f
119122        autocast_ctx  =  torch .autocast (accelerator .device .type , weight_dtype )
120123
121124    for  validation_prompt , validation_image  in  zip (validation_prompts , validation_images ):
122-         from  diffusers .utils  import  load_image 
123- 
124125        validation_image  =  load_image (validation_image )
125126        # maybe need to inference on 1024 to get a good image 
126127        validation_image  =  validation_image .resize ((args .resolution , args .resolution ))
@@ -136,6 +137,9 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f
136137                    num_inference_steps = 50 ,
137138                    guidance_scale = args .guidance_scale ,
138139                    generator = generator ,
140+                     max_sequence_length = 512 ,
141+                     height = 1024 ,
142+                     width = 1024 ,
139143                ).images [0 ]
140144            image  =  image .resize ((args .resolution , args .resolution ))
141145            images .append (image )
@@ -824,7 +828,7 @@ def main(args):
824828            new_linear .bias .copy_ (flux_transformer .x_embedder .bias )
825829        flux_transformer .x_embedder  =  new_linear 
826830
827-     assert  torch .all (new_linear .weight [:, initial_input_channels :].data  ==  0 )
831+     assert  torch .all (flux_transformer . x_embedder .weight [:, initial_input_channels :].data  ==  0 )
828832    flux_transformer .register_to_config (in_channels = initial_input_channels  *  2 )
829833
830834    if  args .lora_layers  is  not None :
@@ -963,10 +967,8 @@ def load_model_hook(models, input_dir):
963967
964968    # Optimization parameters 
965969    transformer_lora_parameters  =  list (filter (lambda  p : p .requires_grad , flux_transformer .parameters ()))
966-     transformer_parameters_with_lr  =  {"params" : transformer_lora_parameters , "lr" : args .learning_rate }
967-     params_to_optimize  =  [transformer_parameters_with_lr ]
968970    optimizer  =  optimizer_class (
969-         params_to_optimize ,
971+         transformer_lora_parameters ,
970972        lr = args .learning_rate ,
971973        betas = (args .adam_beta1 , args .adam_beta2 ),
972974        weight_decay = args .adam_weight_decay ,
@@ -1101,8 +1103,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
11011103            with  accelerator .accumulate (flux_transformer ):
11021104                # Convert images to latent space 
11031105                # vae encode 
1104-                 pixel_latents  =  encode_image (batch ["pixel_values" ], vae .to (accelerator .device ), weight_dtype )
1105-                 control_latents  =  encode_image (
1106+                 pixel_latents  =  encode_images (batch ["pixel_values" ], vae .to (accelerator .device ), weight_dtype )
1107+                 control_latents  =  encode_images (
11061108                    batch ["conditioning_pixel_values" ], vae .to (accelerator .device ), weight_dtype 
11071109                )
11081110                # offload vae to CPU. 
@@ -1273,7 +1275,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
12731275        image_logs  =  None 
12741276        if  args .validation_prompt  is  not None :
12751277            image_logs  =  log_validation (
1276-                 flux_transformer = flux_transformer ,
1278+                 flux_transformer = None ,
12771279                args = args ,
12781280                accelerator = accelerator ,
12791281                weight_dtype = weight_dtype ,
0 commit comments