Skip to content

Commit 1583a56

Browse files
committed
fix: registration of out_channels in the control flux scripts.
1 parent 9d2c8d8 commit 1583a56

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

examples/flux-control/train_control_flux.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -795,7 +795,7 @@ def main(args):
795795
flux_transformer.x_embedder = new_linear
796796

797797
assert torch.all(flux_transformer.x_embedder.weight[:, initial_input_channels:].data == 0)
798-
flux_transformer.register_to_config(in_channels=initial_input_channels * 2)
798+
flux_transformer.register_to_config(in_channels=initial_input_channels * 2, out_channels=initial_input_channels)
799799

800800
def unwrap_model(model):
801801
model = accelerator.unwrap_model(model)

examples/flux-control/train_control_lora_flux.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -830,7 +830,7 @@ def main(args):
830830
flux_transformer.x_embedder = new_linear
831831

832832
assert torch.all(flux_transformer.x_embedder.weight[:, initial_input_channels:].data == 0)
833-
flux_transformer.register_to_config(in_channels=initial_input_channels * 2)
833+
flux_transformer.register_to_config(in_channels=initial_input_channels * 2, out_channels=initial_input_channels)
834834

835835
if args.train_norm_layers:
836836
for name, param in flux_transformer.named_parameters():

0 commit comments

Comments
 (0)