From 1583a5631251632e17bd4fda54d894d6ed740e11 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 24 Dec 2024 08:53:13 +0530 Subject: [PATCH 1/2] fix: registration of out_channels in the control flux scripts. --- examples/flux-control/train_control_flux.py | 2 +- examples/flux-control/train_control_lora_flux.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/flux-control/train_control_flux.py b/examples/flux-control/train_control_flux.py index 1432e346f0ce..5ac76d793fdb 100644 --- a/examples/flux-control/train_control_flux.py +++ b/examples/flux-control/train_control_flux.py @@ -795,7 +795,7 @@ def main(args): flux_transformer.x_embedder = new_linear assert torch.all(flux_transformer.x_embedder.weight[:, initial_input_channels:].data == 0) - flux_transformer.register_to_config(in_channels=initial_input_channels * 2) + flux_transformer.register_to_config(in_channels=initial_input_channels * 2, out_channels=initial_input_channels) def unwrap_model(model): model = accelerator.unwrap_model(model) diff --git a/examples/flux-control/train_control_lora_flux.py b/examples/flux-control/train_control_lora_flux.py index 6d84e81d810a..7c4c481cedfe 100644 --- a/examples/flux-control/train_control_lora_flux.py +++ b/examples/flux-control/train_control_lora_flux.py @@ -830,7 +830,7 @@ def main(args): flux_transformer.x_embedder = new_linear assert torch.all(flux_transformer.x_embedder.weight[:, initial_input_channels:].data == 0) - flux_transformer.register_to_config(in_channels=initial_input_channels * 2) + flux_transformer.register_to_config(in_channels=initial_input_channels * 2, out_channels=initial_input_channels) if args.train_norm_layers: for name, param in flux_transformer.named_parameters(): From e8b865544e2043a2d2c5f56256187f33006f836a Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 24 Dec 2024 20:59:29 +0530 Subject: [PATCH 2/2] free memory. --- examples/flux-control/train_control_flux.py | 5 +++++ examples/flux-control/train_control_lora_flux.py | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/examples/flux-control/train_control_flux.py b/examples/flux-control/train_control_flux.py index 5ac76d793fdb..35f9a5f80342 100644 --- a/examples/flux-control/train_control_flux.py +++ b/examples/flux-control/train_control_flux.py @@ -1166,6 +1166,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): flux_transformer.to(torch.float32) flux_transformer.save_pretrained(args.output_dir) + del flux_transformer + del text_encoding_pipeline + del vae + free_memory() + # Run a final round of validation. image_logs = None if args.validation_prompt is not None: diff --git a/examples/flux-control/train_control_lora_flux.py b/examples/flux-control/train_control_lora_flux.py index 7c4c481cedfe..b176a685c963 100644 --- a/examples/flux-control/train_control_lora_flux.py +++ b/examples/flux-control/train_control_lora_flux.py @@ -1319,6 +1319,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): transformer_lora_layers=transformer_lora_layers, ) + del flux_transformer + del text_encoding_pipeline + del vae + free_memory() + # Run a final round of validation. image_logs = None if args.validation_prompt is not None: