
Commit 90708fa ("fixes")
1 parent a6158d7

File tree: 2 files changed, +56 -41 lines

examples/control-lora/README.md (2 additions, 0 deletions)

````diff
@@ -32,4 +32,6 @@ accelerate launch train_control_lora_flux.py \
   --push_to_hub
 ```
 
+`openpose.png` comes from [here](https://huggingface.co/Adapter/t2iadapter/resolve/main/openpose.png).
+
 You need to install `diffusers` from the branch of [this PR](https://github.com/huggingface/diffusers/pull/9999).
````

examples/control-lora/train_control_lora_flux.py (54 additions, 41 deletions)
```diff
@@ -25,7 +25,6 @@
 import accelerate
 import numpy as np
 import torch
-import torch.nn.functional as F
 import torch.utils.checkpoint
 import transformers
 from accelerate import Accelerator
```
```diff
@@ -43,7 +42,12 @@
 import diffusers
 from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxControlPipeline, FluxTransformer2DModel
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import cast_training_params, compute_density_for_timestep_sampling, free_memory
+from diffusers.training_utils import (
+    cast_training_params,
+    compute_density_for_timestep_sampling,
+    compute_loss_weighting_for_sd3,
+    free_memory,
+)
 from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
 from diffusers.utils.torch_utils import is_compiled_module
```
```diff
@@ -550,7 +554,7 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--weighting_scheme",
         type=str,
-        default="logit_normal",
+        default="none",
         choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"],
         help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'),
     )
```
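The default switch matters because `--weighting_scheme` controls how training timesteps are drawn. Below is a minimal sketch of the distinction, not the library implementation; the authoritative logic lives in `compute_density_for_timestep_sampling` in `diffusers.training_utils`:

```python
import torch

# Sketch (assumed behavior): with "none", u is uniform over [0, 1) and any
# weighting is applied to the loss afterwards; "logit_normal" biases u toward
# mid-range timesteps instead.
def sample_u(weighting_scheme: str, batch_size: int, logit_mean=0.0, logit_std=1.0):
    if weighting_scheme == "logit_normal":
        return torch.sigmoid(torch.normal(logit_mean, logit_std, size=(batch_size,)))
    return torch.rand(batch_size)  # "none": uniform sampling

u = sample_u("none", batch_size=4)
indices = (u * 1000).long()  # map to discrete scheduler timesteps, as in the script
```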
```diff
@@ -566,11 +570,6 @@
         default=1.29,
         help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
     )
-    parser.add_argument(
-        "--enable_model_cpu_offload",
-        action="store_true",
-        help="Enable model cpu offload and save memory.",
-    )
 
     if input_args is not None:
         args = parser.parse_args(input_args)
```
```diff
@@ -672,7 +671,8 @@ def prepare_train_dataset(dataset, accelerator):
         [
             transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
             transforms.CenterCrop(args.resolution),
-            transforms.Lambda(lambda x: x / 127.5 - 1.0),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
         ]
     )
 
```
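This fix is presumably needed because `Resize`/`CenterCrop` here operate on PIL images, which the removed `Lambda` (expecting a `[0, 255]` tensor) could not handle; `ToTensor` plus `Normalize(0.5, 0.5)` is the idiomatic way to land in `[-1, 1]`. A quick self-contained check:

```python
from PIL import Image
from torchvision import transforms

# ToTensor maps a PIL image from [0, 255] uint8 to a [0, 1] float tensor;
# Normalize with mean=std=0.5 then rescales each channel to [-1, 1].
preprocess = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

img = Image.new("RGB", (64, 64), color=(255, 255, 255))  # dummy all-white image
t = preprocess(img)
assert float(t.max()) == 1.0 and float(t.min()) == 1.0  # white maps to exactly 1.0
```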
```diff
@@ -735,7 +735,7 @@ def main(args):
 
     # Disable AMP for MPS. A technique for accelerating machine learning computations on iOS and macOS devices.
     if torch.backends.mps.is_available():
-        print("MPS is enabled. Disabling AMP.")
+        logger.info("MPS is enabled. Disabling AMP.")
         accelerator.native_amp = False
 
     # Make one log on every process with the configuration for debugging.
```
```diff
@@ -776,6 +776,7 @@
         revision=args.revision,
         variant=args.variant,
     )
+    vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
    flux_transformer = FluxTransformer2DModel.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="transformer",
```
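The new `vae_scale_factor` is used later to un-pack the model prediction back to latent height/width. Each VAE down block halves spatial resolution, hence the formula; a quick sketch with illustrative (not verified) config values:

```python
# Illustrative sketch: with 4 entries in block_out_channels (assumed here),
# images are downsampled spatially by 2 ** (4 - 1) = 8.
block_out_channels = (128, 256, 512, 512)  # assumed values, for illustration
vae_scale_factor = 2 ** (len(block_out_channels) - 1)
assert vae_scale_factor == 8
assert 1024 // vae_scale_factor == 128  # a 1024px image gives 128px latents
```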
```diff
@@ -817,6 +818,8 @@
         new_linear.weight[:, :initial_input_channels].copy_(flux_transformer.x_embedder.weight)
         new_linear.bias.copy_(flux_transformer.x_embedder.bias)
         flux_transformer.x_embedder = new_linear
+
+    assert torch.all(new_linear.weight[:, initial_input_channels:].data == 0)
     flux_transformer.register_to_config(in_channels=initial_input_channels * 2)
 
     if args.lora_layers is not None:
```
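The added assert checks that the weight columns reserved for the control latents really start at zero, so at initialization the widened `x_embedder` behaves exactly like the pretrained layer. A standalone sketch of this zero-init expansion pattern (shapes are illustrative, not Flux's actual ones):

```python
import torch

# Zero-init input expansion: double a Linear layer's input features, copy the
# pretrained weights into the first half, and zero the new (control) half.
initial_input_channels = 64
old_linear = torch.nn.Linear(initial_input_channels, 128)

new_linear = torch.nn.Linear(initial_input_channels * 2, 128)
with torch.no_grad():
    new_linear.weight.zero_()
    new_linear.weight[:, :initial_input_channels].copy_(old_linear.weight)
    new_linear.bias.copy_(old_linear.bias)

assert torch.all(new_linear.weight[:, initial_input_channels:] == 0)

# Because the control columns are zero, the widened layer initially ignores
# the control half of the input and reproduces the pretrained output.
x, ctrl = torch.randn(2, 64), torch.randn(2, 64)
assert torch.allclose(new_linear(torch.cat([x, ctrl], dim=1)), old_linear(x), atol=1e-6)
```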
```diff
@@ -1092,24 +1095,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 # offload vae to CPU.
                 vae.cpu()
 
-                # pack the latents.
-                packed_pixel_latents = FluxControlPipeline._pack_latents(
-                    pixel_latents,
-                    batch_size=pixel_latents.shape[0],
-                    num_channels_latents=pixel_latents.shape[1],
-                    height=pixel_latents.shape[2],
-                    width=pixel_latents.shape[3],
-                )
-                packed_control_latents = FluxControlPipeline._pack_latents(
-                    pixel_latents,
-                    batch_size=control_latents.shape[0],
-                    num_channels_latents=control_latents.shape[1],
-                    height=control_latents.shape[2],
-                    width=control_latents.shape[3],
-                )
-                # concate across channels.
-                latent_model_input = torch.cat([packed_pixel_latents, packed_control_latents], dim=2)
-
                 # Sample a random timestep for each image
                 # for weighting schemes where we sample timesteps non-uniformly
                 bsz = pixel_latents.shape[0]
```
```diff
@@ -1122,25 +1107,37 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                     mode_scale=args.mode_scale,
                 )
                 indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
-                timesteps = noise_scheduler_copy.timesteps[indices].to(device=latent_model_input.device)
+                timesteps = noise_scheduler_copy.timesteps[indices].to(device=pixel_latents.device)
 
                 # Add noise according to flow matching.
-                sigmas = get_sigmas(timesteps, n_dim=latent_model_input.ndim, dtype=latent_model_input.dtype)
-                noisy_model_input = (1.0 - sigmas) * latent_model_input + sigmas * noise
+                sigmas = get_sigmas(timesteps, n_dim=pixel_latents.ndim, dtype=pixel_latents.dtype)
+                noisy_model_input = (1.0 - sigmas) * pixel_latents + sigmas * noise
+                # Concatenate across channels.
+                # Question: Should we concatenate before adding noise?
+                concatenated_noisy_model_input = torch.cat([noisy_model_input, control_latents], dim=1)
+
+                # pack the latents.
+                packed_noisy_model_input = FluxControlPipeline._pack_latents(
+                    concatenated_noisy_model_input,
+                    batch_size=bsz,
+                    num_channels_latents=concatenated_noisy_model_input.shape[1],
+                    height=concatenated_noisy_model_input.shape[2],
+                    width=concatenated_noisy_model_input.shape[3],
+                )
 
                 # latent image ids for RoPE.
                 latent_image_ids = FluxControlPipeline._prepare_latent_image_ids(
-                    pixel_latents.shape[0],
-                    pixel_latents.shape[2] // 2,
-                    pixel_latents.shape[3] // 2,
+                    bsz,
+                    concatenated_noisy_model_input.shape[2] // 2,
+                    concatenated_noisy_model_input.shape[3] // 2,
                     accelerator.device,
                     weight_dtype,
                 )
 
                 # handle guidance
                 if flux_transformer.config.guidance_embeds:
                     guidance_vec = torch.full(
-                        (noisy_model_input.shape[0],),
+                        (bsz,),
                         args.guidance_scale,
                         device=noisy_model_input.device,
                         dtype=weight_dtype,
```
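Packing now happens after noising and channel-concatenation, so noise is added to the pixel latents alone in `(B, C, H, W)` space. `_pack_latents` then patchifies the concatenated tensor into a sequence of 2x2 patches, which is also why `_prepare_latent_image_ids` receives `H // 2` and `W // 2`. A sketch of the shape contract, written from the pipeline's observed behavior (see the diffusers source for the authoritative helper):

```python
import torch

# Sketch of the 2x2 patchify done by FluxControlPipeline._pack_latents:
# (B, C, H, W) -> (B, (H // 2) * (W // 2), C * 4).
def pack_latents(latents: torch.Tensor) -> torch.Tensor:
    b, c, h, w = latents.shape
    latents = latents.view(b, c, h // 2, 2, w // 2, 2)
    latents = latents.permute(0, 2, 4, 1, 3, 5)  # group each 2x2 patch with its channels
    return latents.reshape(b, (h // 2) * (w // 2), c * 4)

# 16 pixel-latent channels concatenated with 16 control channels:
x = torch.randn(1, 32, 64, 64)
print(pack_latents(x).shape)  # torch.Size([1, 1024, 128])
```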
```diff
@@ -1152,12 +1149,14 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 captions = batch["captions"]
                 text_encoding_pipeline = text_encoding_pipeline.to("cuda")
                 with torch.no_grad():
-                    prompt_embeds, pooled_prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt(captions)
+                    prompt_embeds, pooled_prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt(
+                        captions, prompt_2=None
+                    )
                 text_encoding_pipeline = text_encoding_pipeline.to("cuda")
 
                 # Predict.
-                noise_pred = flux_transformer(
-                    hidden_states=noisy_model_input,
+                model_pred = flux_transformer(
+                    hidden_states=packed_noisy_model_input,
                     timestep=timesteps / 1000,
                     guidance=guidance_vec,
                     pooled_projections=pooled_prompt_embeds,
```
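The `prompt_2=None` addition matters because Flux encodes prompts with two text encoders (CLIP and T5), and `encode_prompt` apparently takes `prompt_2` as an explicit argument; passing `None` reuses the same captions for the second encoder. A hedged usage sketch; the model id and the `transformer=None`/`vae=None` trick (loading only the text-encoder stack) are assumptions, and `FluxControlPipeline` itself comes from the PR branch:

```python
import torch
from diffusers import FluxControlPipeline  # available on the PR branch

# Assumed setup: load only the text encoders to embed captions.
pipe = FluxControlPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=None, vae=None, torch_dtype=torch.bfloat16
)
with torch.no_grad():
    # prompt_2=None falls back to reusing the first prompt for the T5 encoder.
    prompt_embeds, pooled_prompt_embeds, text_ids = pipe.encode_prompt(
        ["a photo of a cat"], prompt_2=None
    )
```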
```diff
@@ -1166,10 +1165,24 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                     img_ids=latent_image_ids,
                     return_dict=False,
                 )[0]
-
-                loss = F.mse_loss(noise_pred.float(), (noise - pixel_latents).float(), reduction="mean")
+                model_pred = FluxControlPipeline._unpack_latents(
+                    model_pred,
+                    height=noisy_model_input.shape[2] * vae_scale_factor,
+                    width=noisy_model_input.shape[3] * vae_scale_factor,
+                    vae_scale_factor=vae_scale_factor,
+                )
+                # these weighting schemes use a uniform timestep sampling
+                # and instead post-weight the loss
+                weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
+
+                # flow-matching loss
+                target = noise - pixel_latents
+                loss = torch.mean(
+                    (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
+                    1,
+                )
+                loss = loss.mean()
                 accelerator.backward(loss)
-                # Check if the gradient of each model parameter contains NaN
 
                 if accelerator.sync_gradients:
                     params_to_clip = flux_transformer.parameters()
```
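The prediction is unpacked back to `(B, C, H, W)` so the loss can be computed against the unpacked flow-matching target. With uniform timestep sampling the loss is post-weighted by a sigma-dependent factor; a minimal, self-contained restatement of the computation above:

```python
import torch

# Weighted flow-matching loss: per-sample MSE against the velocity target
# (noise - clean latents), scaled by a sigma-dependent weighting. With
# weighting_scheme="none" the weighting is all ones, i.e. plain MSE.
def weighted_fm_loss(model_pred, noise, clean_latents, weighting):
    target = noise - clean_latents
    per_elem = weighting.float() * (model_pred.float() - target.float()) ** 2
    return per_elem.reshape(target.shape[0], -1).mean(dim=1).mean()

pred = torch.randn(2, 16, 8, 8)
noise = torch.randn_like(pred)
clean = torch.randn_like(pred)
weighting = torch.ones(2, 1, 1, 1)  # broadcast like sigmas in the script
print(weighted_fm_loss(pred, noise, clean, weighting))
```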
