@@ -923,11 +923,28 @@ def load_model_hook(models, input_dir):
923923 transformer_ = model
924924 else :
925925 raise ValueError (f"unexpected save model: { model .__class__ } " )
926-
927926 else :
928927 transformer_ = FluxTransformer2DModel .from_pretrained (
929928 args .pretrained_model_name_or_path , subfolder = "transformer"
930929 ).to (accelerator .device , weight_dtype )
930+
931+ # Handle input dimension doubling before adding adapter
932+ with torch .no_grad ():
933+ initial_input_channels = transformer_ .config .in_channels
934+ new_linear = torch .nn .Linear (
935+ transformer_ .x_embedder .in_features * 2 ,
936+ transformer_ .x_embedder .out_features ,
937+ bias = transformer_ .x_embedder .bias is not None ,
938+ dtype = transformer_ .dtype ,
939+ device = transformer_ .device ,
940+ )
941+ new_linear .weight .zero_ ()
942+ new_linear .weight [:, :initial_input_channels ].copy_ (transformer_ .x_embedder .weight )
943+ if transformer_ .x_embedder .bias is not None :
944+ new_linear .bias .copy_ (transformer_ .x_embedder .bias )
945+ transformer_ .x_embedder = new_linear
946+ transformer_ .register_to_config (in_channels = initial_input_channels * 2 )
947+
931948 transformer_ .add_adapter (transformer_lora_config )
932949
933950 lora_state_dict = FluxControlPipeline .lora_state_dict (input_dir )
0 commit comments