@@ -923,11 +923,28 @@ def load_model_hook(models, input_dir):
                     transformer_ = model
                 else:
                     raise ValueError(f"unexpected save model: {model.__class__}")
-
         else:
             transformer_ = FluxTransformer2DModel.from_pretrained(
                 args.pretrained_model_name_or_path, subfolder="transformer"
             ).to(accelerator.device, weight_dtype)
+
+            # Handle input dimension doubling before adding adapter
+            with torch.no_grad():
+                initial_input_channels = transformer_.config.in_channels
+                new_linear = torch.nn.Linear(
+                    transformer_.x_embedder.in_features * 2,
+                    transformer_.x_embedder.out_features,
+                    bias=transformer_.x_embedder.bias is not None,
+                    dtype=transformer_.dtype,
+                    device=transformer_.device,
+                )
+                new_linear.weight.zero_()
+                new_linear.weight[:, :initial_input_channels].copy_(transformer_.x_embedder.weight)
+                if transformer_.x_embedder.bias is not None:
+                    new_linear.bias.copy_(transformer_.x_embedder.bias)
+                transformer_.x_embedder = new_linear
+                transformer_.register_to_config(in_channels=initial_input_channels * 2)
+
             transformer_.add_adapter(transformer_lora_config)

         lora_state_dict = FluxControlPipeline.lora_state_dict(input_dir)
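The added block widens the transformer's x_embedder so an extra conditioning input (the control latents in Flux Control training) can be concatenated with the noisy latents along the channel dimension: the replacement torch.nn.Linear has twice the input features, its extra weight columns are zero-initialized, and the pretrained weight and bias are copied into the original columns. A minimal sketch of that zero-padded expansion on a standalone linear layer (illustrative sizes, not part of the commit):

import torch

old = torch.nn.Linear(4, 6)   # stand-in for the pretrained x_embedder (4 input channels)
new = torch.nn.Linear(8, 6, bias=old.bias is not None)   # doubled input channels

with torch.no_grad():
    new.weight.zero_()                    # extra columns start at zero
    new.weight[:, :4].copy_(old.weight)   # reuse pretrained weights for the original channels
    if old.bias is not None:
        new.bias.copy_(old.bias)

x = torch.randn(2, 4)        # "noisy latents"
control = torch.zeros(2, 4)  # zeroed "control latents"

# With the control half zeroed, the widened layer matches the original exactly,
# so training starts from the pretrained behavior.
assert torch.allclose(old(x), new(torch.cat([x, control], dim=-1)), atol=1e-6)

Because the new columns start at zero, the expanded embedder initially reproduces the pretrained projection; the control input only begins to influence the output once training updates those weights.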