Commit 4b9f1c7

Add correct number of channels when resuming from checkpoint for Flux Control LoRA training (#10422)

* Add correct number of channels when resuming from checkpoint
* Fix formatting
1 parent 91008aa · commit 4b9f1c7

1 file changed: +18 −1 lines changed

examples/flux-control/train_control_lora_flux.py

Lines changed: 18 additions & 1 deletion
@@ -923,11 +923,28 @@ def load_model_hook(models, input_dir):
                     transformer_ = model
                 else:
                     raise ValueError(f"unexpected save model: {model.__class__}")
-
         else:
             transformer_ = FluxTransformer2DModel.from_pretrained(
                 args.pretrained_model_name_or_path, subfolder="transformer"
             ).to(accelerator.device, weight_dtype)
+
+            # Handle input dimension doubling before adding adapter
+            with torch.no_grad():
+                initial_input_channels = transformer_.config.in_channels
+                new_linear = torch.nn.Linear(
+                    transformer_.x_embedder.in_features * 2,
+                    transformer_.x_embedder.out_features,
+                    bias=transformer_.x_embedder.bias is not None,
+                    dtype=transformer_.dtype,
+                    device=transformer_.device,
+                )
+                new_linear.weight.zero_()
+                new_linear.weight[:, :initial_input_channels].copy_(transformer_.x_embedder.weight)
+                if transformer_.x_embedder.bias is not None:
+                    new_linear.bias.copy_(transformer_.x_embedder.bias)
+                transformer_.x_embedder = new_linear
+                transformer_.register_to_config(in_channels=initial_input_channels * 2)
+
         transformer_.add_adapter(transformer_lora_config)
 
         lora_state_dict = FluxControlPipeline.lora_state_dict(input_dir)
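Background on the change: Flux Control training feeds the control latents concatenated with the noisy image latents along the channel dimension, so the transformer's x_embedder has to accept twice the pretrained in_channels. The commit applies that widening inside load_model_hook as well, so a transformer rebuilt from the pretrained weights lines up with the checkpoint being resumed. Below is a minimal, self-contained sketch of the widening trick, not taken from the training script; the names C, D, embedder, and widened are illustrative placeholders.

    # Minimal sketch of the channel-widening trick (illustrative, not the training script).
    # `C`, `D`, `embedder`, and `widened` are hypothetical names.
    import torch

    C, D = 64, 3072  # assumed in/out feature sizes of a patch-embedding Linear layer

    embedder = torch.nn.Linear(C, D)  # stands in for the pretrained x_embedder

    with torch.no_grad():
        widened = torch.nn.Linear(C * 2, D, bias=embedder.bias is not None)
        widened.weight.zero_()                        # new (control) half starts at zero
        widened.weight[:, :C].copy_(embedder.weight)  # pretrained weights for the original channels
        if embedder.bias is not None:
            widened.bias.copy_(embedder.bias)

    # With the extra control channels at zero, the widened layer matches the original one,
    # so the pretrained behaviour is preserved before any adapter training happens.
    x = torch.randn(2, C)
    control = torch.zeros(2, C)
    assert torch.allclose(embedder(x), widened(torch.cat([x, control], dim=-1)))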
