@@ -923,11 +923,28 @@ def load_model_hook(models, input_dir):
923923                        transformer_  =  model 
924924                    else :
925925                        raise  ValueError (f"unexpected save model: { model .__class__ }  " )
926- 
927926            else :
928927                transformer_  =  FluxTransformer2DModel .from_pretrained (
929928                    args .pretrained_model_name_or_path , subfolder = "transformer" 
930929                ).to (accelerator .device , weight_dtype )
930+ 
931+                 # Handle input dimension doubling before adding adapter 
932+                 with  torch .no_grad ():
933+                     initial_input_channels  =  transformer_ .config .in_channels 
934+                     new_linear  =  torch .nn .Linear (
935+                         transformer_ .x_embedder .in_features  *  2 ,
936+                         transformer_ .x_embedder .out_features ,
937+                         bias = transformer_ .x_embedder .bias  is  not   None ,
938+                         dtype = transformer_ .dtype ,
939+                         device = transformer_ .device ,
940+                     )
941+                     new_linear .weight .zero_ ()
942+                     new_linear .weight [:, :initial_input_channels ].copy_ (transformer_ .x_embedder .weight )
943+                     if  transformer_ .x_embedder .bias  is  not   None :
944+                         new_linear .bias .copy_ (transformer_ .x_embedder .bias )
945+                     transformer_ .x_embedder  =  new_linear 
946+                     transformer_ .register_to_config (in_channels = initial_input_channels  *  2 )
947+ 
931948                transformer_ .add_adapter (transformer_lora_config )
932949
933950            lora_state_dict  =  FluxControlPipeline .lora_state_dict (input_dir )
0 commit comments