Skip to content

Commit 7521fec

Browse files
committed
updates
1 parent 67bc7e4 commit 7521fec

File tree

1 file changed

+15
-13
lines changed

1 file changed

+15
-13
lines changed

examples/control-lora/train_control_lora_flux.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
import accelerate
2727
import numpy as np
2828
import torch
29-
import torch.utils.checkpoint
3029
import transformers
3130
from accelerate import Accelerator
3231
from accelerate.logging import get_logger
@@ -49,7 +48,7 @@
4948
compute_loss_weighting_for_sd3,
5049
free_memory,
5150
)
52-
from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
51+
from diffusers.utils import check_min_version, is_wandb_available, load_image, make_image_grid
5352
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
5453
from diffusers.utils.torch_utils import is_compiled_module
5554

@@ -63,17 +62,17 @@
6362
logger = get_logger(__name__)
6463

6564

66-
def encode_image(pixels: torch.Tensor, vae: torch.nn.Module, weight_dtype):
65+
def encode_images(pixels: torch.Tensor, vae: torch.nn.Module, weight_dtype):
6766
pixel_latents = vae.encode(pixels.to(vae.dtype)).latent_dist.sample()
6867
pixel_latents = (pixel_latents - vae.config.shift_factor) * vae.config.scaling_factor
6968
return pixel_latents.to(weight_dtype)
7069

7170

7271
def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_final_validation=False):
7372
logger.info("Running validation... ")
74-
flux_transformer = accelerator.unwrap_model(flux_transformer)
7573

7674
if not is_final_validation:
75+
flux_transformer = accelerator.unwrap_model(flux_transformer)
7776
pipeline = FluxControlPipeline.from_pretrained(
7877
args.pretrained_model_name_or_path,
7978
transformer=flux_transformer,
@@ -83,12 +82,16 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f
8382
transformer = FluxTransformer2DModel.from_pretrained(
8483
args.pretrained_model_name_or_path, subfolder="transformer", torch_dtype=weight_dtype
8584
)
85+
initial_channels = transformer.config.in_channels
8686
pipeline = FluxControlPipeline.from_pretrained(
8787
args.pretrained_model_name_or_path,
8888
transformer=transformer,
8989
torch_dtype=weight_dtype,
9090
)
9191
pipeline.load_lora_weights(args.output_dir)
92+
assert (
93+
pipeline.transformer.config.in_channels == initial_channels * 2
94+
), f"{pipeline.transformer.config.in_channels=}"
9295

9396
pipeline.to(accelerator.device)
9497
pipeline.set_progress_bar_config(disable=True)
@@ -119,8 +122,6 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f
119122
autocast_ctx = torch.autocast(accelerator.device.type, weight_dtype)
120123

121124
for validation_prompt, validation_image in zip(validation_prompts, validation_images):
122-
from diffusers.utils import load_image
123-
124125
validation_image = load_image(validation_image)
125126
# maybe need to inference on 1024 to get a good image
126127
validation_image = validation_image.resize((args.resolution, args.resolution))
@@ -136,6 +137,9 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f
136137
num_inference_steps=50,
137138
guidance_scale=args.guidance_scale,
138139
generator=generator,
140+
max_sequence_length=512,
141+
height=1024,
142+
width=1024,
139143
).images[0]
140144
image = image.resize((args.resolution, args.resolution))
141145
images.append(image)
@@ -824,7 +828,7 @@ def main(args):
824828
new_linear.bias.copy_(flux_transformer.x_embedder.bias)
825829
flux_transformer.x_embedder = new_linear
826830

827-
assert torch.all(new_linear.weight[:, initial_input_channels:].data == 0)
831+
assert torch.all(flux_transformer.x_embedder.weight[:, initial_input_channels:].data == 0)
828832
flux_transformer.register_to_config(in_channels=initial_input_channels * 2)
829833

830834
if args.lora_layers is not None:
@@ -963,10 +967,8 @@ def load_model_hook(models, input_dir):
963967

964968
# Optimization parameters
965969
transformer_lora_parameters = list(filter(lambda p: p.requires_grad, flux_transformer.parameters()))
966-
transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate}
967-
params_to_optimize = [transformer_parameters_with_lr]
968970
optimizer = optimizer_class(
969-
params_to_optimize,
971+
transformer_lora_parameters,
970972
lr=args.learning_rate,
971973
betas=(args.adam_beta1, args.adam_beta2),
972974
weight_decay=args.adam_weight_decay,
@@ -1101,8 +1103,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
11011103
with accelerator.accumulate(flux_transformer):
11021104
# Convert images to latent space
11031105
# vae encode
1104-
pixel_latents = encode_image(batch["pixel_values"], vae.to(accelerator.device), weight_dtype)
1105-
control_latents = encode_image(
1106+
pixel_latents = encode_images(batch["pixel_values"], vae.to(accelerator.device), weight_dtype)
1107+
control_latents = encode_images(
11061108
batch["conditioning_pixel_values"], vae.to(accelerator.device), weight_dtype
11071109
)
11081110
# offload vae to CPU.
@@ -1273,7 +1275,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
12731275
image_logs = None
12741276
if args.validation_prompt is not None:
12751277
image_logs = log_validation(
1276-
flux_transformer=flux_transformer,
1278+
flux_transformer=None,
12771279
args=args,
12781280
accelerator=accelerator,
12791281
weight_dtype=weight_dtype,

0 commit comments

Comments
 (0)