Commit ddc6164

Merge branch 'main' into feat-taylorseer
2 parents d06c6bc + 1b91856 commit ddc6164

File tree: 15 files changed, +729 -96 lines changed

docs/source/en/api/pipelines/flux2.md

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@ specific language governing permissions and limitations under the License.
 
 Flux.2 is the recent series of image generation models from Black Forest Labs, preceded by the [Flux.1](./flux.md) series. It is an entirely new model with a new architecture and pre-training done from scratch!
 
-Original model checkpoints for Flux can be found [here](https://huggingface.co/black-forest-labs). Original inference code can be found [here](https://github.com/black-forest-labs/flux2-dev).
+Original model checkpoints for Flux can be found [here](https://huggingface.co/black-forest-labs). Original inference code can be found [here](https://github.com/black-forest-labs/flux2).
 
 > [!TIP]
 > Flux2 can be quite expensive to run on consumer hardware devices. However, you can perform a suite of optimizations to run it faster and in a more memory-friendly manner. Check out [this section](https://huggingface.co/blog/sd3#memory-optimizations-for-sd3) for more details. Additionally, Flux can benefit from quantization for memory efficiency with a trade-off in inference latency. Refer to [this blog post](https://huggingface.co/blog/quanto-diffusers) to learn more.
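For reference, a minimal sketch of the quantization path the tip above points to, using optimum-quanto as in the linked blog post. The `Flux2Pipeline` class name and the `black-forest-labs/FLUX.2-dev` checkpoint id are assumptions, not part of this diff.

```python
import torch
from diffusers import Flux2Pipeline  # assumed pipeline class for Flux.2
from optimum.quanto import freeze, qfloat8, quantize

# Assumed checkpoint id; substitute the actual Flux.2 repository.
pipe = Flux2Pipeline.from_pretrained("black-forest-labs/FLUX.2-dev", torch_dtype=torch.bfloat16)

# Quantize the transformer weights to 8-bit float: lower memory, some extra latency.
quantize(pipe.transformer, weights=qfloat8)
freeze(pipe.transformer)

pipe.enable_model_cpu_offload()  # keep idle components on CPU for further savings
image = pipe("a photo of a forest at dawn", num_inference_steps=28).images[0]
image.save("flux2_quantized.png")
```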

examples/dreambooth/README_flux2.md

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 
 [DreamBooth](https://huggingface.co/papers/2208.12242) is a method to personalize image generation models given just a few (3~5) images of a subject/concept.
 
-The `train_dreambooth_lora_flux2.py` script shows how to implement the training procedure for [LoRAs](https://huggingface.co/blog/lora) and adapt it for [FLUX.2 [dev]](https://github.com/black-forest-labs/flux2-dev).
+The `train_dreambooth_lora_flux2.py` script shows how to implement the training procedure for [LoRAs](https://huggingface.co/blog/lora) and adapt it for [FLUX.2 [dev]](https://github.com/black-forest-labs/flux2).
 
 > [!NOTE]
 > **Memory consumption**

examples/text_to_image/train_text_to_image_lora.py

Lines changed: 61 additions & 13 deletions
@@ -37,7 +37,7 @@
 from huggingface_hub import create_repo, upload_folder
 from packaging import version
 from peft import LoraConfig
-from peft.utils import get_peft_model_state_dict
+from peft.utils import get_peft_model_state_dict, set_peft_model_state_dict
 from torchvision import transforms
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
@@ -46,7 +46,12 @@
 from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, StableDiffusionPipeline, UNet2DConditionModel
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import cast_training_params, compute_snr
-from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
+from diffusers.utils import (
+    check_min_version,
+    convert_state_dict_to_diffusers,
+    convert_unet_state_dict_to_peft,
+    is_wandb_available,
+)
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.torch_utils import is_compiled_module
@@ -708,6 +713,56 @@ def collate_fn(examples):
         num_workers=args.dataloader_num_workers,
     )
 
+    def save_model_hook(models, weights, output_dir):
+        if accelerator.is_main_process:
+            unet_lora_layers_to_save = None
+
+            for model in models:
+                if isinstance(model, type(unwrap_model(unet))):
+                    unet_lora_layers_to_save = get_peft_model_state_dict(model)
+                else:
+                    raise ValueError(f"Unexpected save model: {model.__class__}")
+
+                # make sure to pop weight so that corresponding model is not saved again
+                weights.pop()
+
+            StableDiffusionPipeline.save_lora_weights(
+                save_directory=output_dir,
+                unet_lora_layers=unet_lora_layers_to_save,
+                safe_serialization=True,
+            )
+
+    def load_model_hook(models, input_dir):
+        unet_ = None
+
+        while len(models) > 0:
+            model = models.pop()
+            if isinstance(model, type(unwrap_model(unet))):
+                unet_ = model
+            else:
+                raise ValueError(f"unexpected save model: {model.__class__}")
+
+        # returns a tuple of state dictionary and network alphas
+        lora_state_dict, network_alphas = StableDiffusionPipeline.lora_state_dict(input_dir)
+
+        unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+        unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
+        incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
+
+        if incompatible_keys is not None:
+            # check only for unexpected keys
+            unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+            # throw warning if some unexpected keys are found and continue loading
+            if unexpected_keys:
+                logger.warning(
+                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+                    f" {unexpected_keys}. "
+                )
+
+        # Make sure the trainable params are in float32
+        if args.mixed_precision in ["fp16"]:
+            cast_training_params([unet_], dtype=torch.float32)
+
     # Scheduler and math around the number of training steps.
     # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
     num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
@@ -732,6 +787,10 @@ def collate_fn(examples):
         unet, optimizer, train_dataloader, lr_scheduler
     )
 
+    # Register the hooks for efficient saving and loading of LoRA weights
+    accelerator.register_save_state_pre_hook(save_model_hook)
+    accelerator.register_load_state_pre_hook(load_model_hook)
+
    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     if args.max_train_steps is None:
@@ -906,17 +965,6 @@ def collate_fn(examples):
                         save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                         accelerator.save_state(save_path)
 
-                        unwrapped_unet = unwrap_model(unet)
-                        unet_lora_state_dict = convert_state_dict_to_diffusers(
-                            get_peft_model_state_dict(unwrapped_unet)
-                        )
-
-                        StableDiffusionPipeline.save_lora_weights(
-                            save_directory=save_path,
-                            unet_lora_layers=unet_lora_state_dict,
-                            safe_serialization=True,
-                        )
-
                         logger.info(f"Saved state to {save_path}")
 
             logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
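For context, a minimal inference-side sketch (not part of this commit) of consuming the LoRA weights that `save_model_hook` now writes into each checkpoint directory via `StableDiffusionPipeline.save_lora_weights`; the base model id and the checkpoint path are illustrative.

```python
import torch
from diffusers import StableDiffusionPipeline

# Illustrative base model; use whatever --pretrained_model_name_or_path the training run used.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Each accelerator.save_state() checkpoint now contains the LoRA weights written by the hook,
# so an intermediate checkpoint directory can be loaded directly.
pipe.load_lora_weights("sd-model-finetuned-lora/checkpoint-500")

image = pipe("a watercolor portrait of a corgi", num_inference_steps=30).images[0]
image.save("lora_sample.png")
```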

src/diffusers/models/transformers/transformer_z_image.py

Lines changed: 17 additions & 3 deletions
@@ -69,7 +69,10 @@ def timestep_embedding(t, dim, max_period=10000):
 
     def forward(self, t):
         t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
-        t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype))
+        weight_dtype = self.mlp[0].weight.dtype
+        if weight_dtype.is_floating_point:
+            t_freq = t_freq.to(weight_dtype)
+        t_emb = self.mlp(t_freq)
         return t_emb
 
 
@@ -126,6 +129,10 @@ def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
         dtype = query.dtype
         query, key = query.to(dtype), key.to(dtype)
 
+        # From [batch, seq_len] to [batch, 1, 1, seq_len] -> broadcast to [batch, heads, seq_len, seq_len]
+        if attention_mask is not None and attention_mask.ndim == 2:
+            attention_mask = attention_mask[:, None, None, :]
+
         # Compute joint attention
         hidden_states = dispatch_attention_fn(
             query,
@@ -306,6 +313,10 @@ def __call__(self, ids: torch.Tensor):
         if self.freqs_cis is None:
             self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
             self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
+        else:
+            # Ensure freqs_cis are on the same device as ids
+            if self.freqs_cis[0].device != device:
+                self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
 
         result = []
         for i in range(len(self.axes_dims)):
@@ -317,6 +328,8 @@ def __call__(self, ids: torch.Tensor):
 class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
     _supports_gradient_checkpointing = True
     _no_split_modules = ["ZImageTransformerBlock"]
+    _repeated_blocks = ["ZImageTransformerBlock"]
+    _skip_layerwise_casting_patterns = ["t_embedder", "cap_embedder"]  # precision sensitive layers
 
     @register_to_config
     def __init__(
@@ -553,8 +566,6 @@ def forward(
         t = t * self.t_scale
         t = self.t_embedder(t)
 
-        adaln_input = t
-
         (
             x,
             cap_feats,
@@ -572,6 +583,9 @@ def forward(
 
         x = torch.cat(x, dim=0)
         x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x)
+
+        # Match t_embedder output dtype to x for layerwise casting compatibility
+        adaln_input = t.type_as(x)
         x[torch.cat(x_inner_pad_mask)] = self.x_pad_token
         x = list(x.split(x_item_seqlens, dim=0))
         x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0))
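The `_skip_layerwise_casting_patterns` and dtype changes above are what make the model compatible with layerwise casting; a minimal sketch of enabling it follows. The checkpoint id is a placeholder, and a top-level `ZImageTransformer2DModel` export is assumed.

```python
import torch
from diffusers import ZImageTransformer2DModel  # assumed top-level export

# Placeholder repo id; substitute a real Z-Image checkpoint.
transformer = ZImageTransformer2DModel.from_pretrained(
    "<z-image-checkpoint>", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Store most weights in fp8 while the precision-sensitive t_embedder / cap_embedder
# (matched by _skip_layerwise_casting_patterns) stay in the compute dtype.
transformer.enable_layerwise_casting(
    storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16
)
```

`_repeated_blocks` similarly marks the transformer blocks for regional compilation (`compile_repeated_blocks`) in recent diffusers releases.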

src/diffusers/pipelines/flux2/pipeline_flux2.py

Lines changed: 0 additions & 1 deletion
@@ -861,7 +861,6 @@ def __call__(
         if output_type == "latent":
             image = latents
         else:
-            torch.save({"pred": latents}, "pred_d.pt")
             latents = self._unpack_latents_with_ids(latents, latent_ids)
 
             latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)

src/diffusers/pipelines/z_image/pipeline_z_image.py

Lines changed: 14 additions & 25 deletions
@@ -165,21 +165,16 @@ def encode_prompt(
         self,
         prompt: Union[str, List[str]],
         device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-        num_images_per_prompt: int = 1,
         do_classifier_free_guidance: bool = True,
         negative_prompt: Optional[Union[str, List[str]]] = None,
         prompt_embeds: Optional[List[torch.FloatTensor]] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         max_sequence_length: int = 512,
-        lora_scale: Optional[float] = None,
     ):
         prompt = [prompt] if isinstance(prompt, str) else prompt
         prompt_embeds = self._encode_prompt(
             prompt=prompt,
             device=device,
-            dtype=dtype,
-            num_images_per_prompt=num_images_per_prompt,
             prompt_embeds=prompt_embeds,
             max_sequence_length=max_sequence_length,
         )
@@ -193,8 +188,6 @@ def encode_prompt(
             negative_prompt_embeds = self._encode_prompt(
                 prompt=negative_prompt,
                 device=device,
-                dtype=dtype,
-                num_images_per_prompt=num_images_per_prompt,
                 prompt_embeds=negative_prompt_embeds,
                 max_sequence_length=max_sequence_length,
             )
@@ -206,12 +199,9 @@ def _encode_prompt(
         self,
         prompt: Union[str, List[str]],
         device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-        num_images_per_prompt: int = 1,
         prompt_embeds: Optional[List[torch.FloatTensor]] = None,
         max_sequence_length: int = 512,
     ) -> List[torch.FloatTensor]:
-        assert num_images_per_prompt == 1
         device = device or self._execution_device
 
         if prompt_embeds is not None:
@@ -417,8 +407,6 @@ def __call__(
                 f"Please adjust the width to a multiple of {vae_scale}."
             )
 
-        assert self.dtype == torch.bfloat16
-        dtype = self.dtype
         device = self._execution_device
 
         self._guidance_scale = guidance_scale
@@ -434,10 +422,6 @@ def __call__(
         else:
             batch_size = len(prompt_embeds)
 
-        lora_scale = (
-            self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
-        )
-
         # If prompt_embeds is provided and prompt is None, skip encoding
         if prompt_embeds is not None and prompt is None:
             if self.do_classifier_free_guidance and negative_prompt_embeds is None:
@@ -455,11 +439,8 @@ def __call__(
             do_classifier_free_guidance=self.do_classifier_free_guidance,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
-            dtype=dtype,
             device=device,
-            num_images_per_prompt=num_images_per_prompt,
             max_sequence_length=max_sequence_length,
-            lora_scale=lora_scale,
         )
 
         # 4. Prepare latent variables
@@ -475,6 +456,14 @@ def __call__(
             generator,
             latents,
         )
+
+        # Repeat prompt_embeds for num_images_per_prompt
+        if num_images_per_prompt > 1:
+            prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)]
+            if self.do_classifier_free_guidance and negative_prompt_embeds:
+                negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)]
+
+        actual_batch_size = batch_size * num_images_per_prompt
         image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2)
 
         # 5. Prepare timesteps
@@ -523,12 +512,12 @@ def __call__(
                 apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0
 
                 if apply_cfg:
-                    latents_typed = latents if latents.dtype == dtype else latents.to(dtype)
+                    latents_typed = latents.to(self.transformer.dtype)
                     latent_model_input = latents_typed.repeat(2, 1, 1, 1)
                     prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds
                     timestep_model_input = timestep.repeat(2)
                 else:
-                    latent_model_input = latents if latents.dtype == dtype else latents.to(dtype)
+                    latent_model_input = latents.to(self.transformer.dtype)
                     prompt_embeds_model_input = prompt_embeds
                     timestep_model_input = timestep
 
@@ -543,11 +532,11 @@ def __call__(
 
                 if apply_cfg:
                     # Perform CFG
-                    pos_out = model_out_list[:batch_size]
-                    neg_out = model_out_list[batch_size:]
+                    pos_out = model_out_list[:actual_batch_size]
+                    neg_out = model_out_list[actual_batch_size:]
 
                     noise_pred = []
-                    for j in range(batch_size):
+                    for j in range(actual_batch_size):
                         pos = pos_out[j].float()
                         neg = neg_out[j].float()
 
@@ -588,11 +577,11 @@ def __call__(
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
 
-        latents = latents.to(dtype)
         if output_type == "latent":
             image = latents
 
         else:
+            latents = latents.to(self.vae.dtype)
            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
 
             image = self.vae.decode(latents, return_dict=False)[0]
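A minimal usage sketch of the `num_images_per_prompt` support added above; the `ZImagePipeline` class name follows the file name and the checkpoint id is a placeholder, both assumptions rather than part of this diff.

```python
import torch
from diffusers import ZImagePipeline  # assumed class name for pipeline_z_image.py

# Placeholder repo id; substitute a real Z-Image checkpoint.
pipe = ZImagePipeline.from_pretrained("<z-image-checkpoint>", torch_dtype=torch.bfloat16).to("cuda")

# Prompt embeddings are now repeated internally, and the CFG split indexes with
# actual_batch_size = batch_size * num_images_per_prompt.
images = pipe(
    "a watercolor fox in the snow",
    num_images_per_prompt=2,
    num_inference_steps=9,
    guidance_scale=0.0,
).images

for idx, image in enumerate(images):
    image.save(f"z_image_{idx}.png")
```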

src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py

Lines changed: 20 additions & 1 deletion
@@ -429,7 +429,22 @@ def multistep_dpm_solver_second_order_update(
         return x_t
 
     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
-    def index_for_timestep(self, timestep, schedule_timesteps=None):
+    def index_for_timestep(
+        self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
+    ) -> int:
+        """
+        Find the index for a given timestep in the schedule.
+
+        Args:
+            timestep (`int` or `torch.Tensor`):
+                The timestep for which to find the index.
+            schedule_timesteps (`torch.Tensor`, *optional*):
+                The timestep schedule to search in. If `None`, uses `self.timesteps`.
+
+        Returns:
+            `int`:
+                The index of the timestep in the schedule.
+        """
         if schedule_timesteps is None:
             schedule_timesteps = self.timesteps
 
@@ -452,6 +467,10 @@ def index_for_timestep(self, timestep, schedule_timesteps=None):
     def _init_step_index(self, timestep):
         """
         Initialize the step_index counter for the scheduler.
+
+        Args:
+            timestep (`int` or `torch.Tensor`):
+                The current timestep for which to initialize the step index.
         """
 
         if self.begin_index is None:
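A minimal sketch exercising the newly documented lookup; `CosineDPMSolverMultistepScheduler` is the class this file defines, constructed here with its default config.

```python
from diffusers import CosineDPMSolverMultistepScheduler

scheduler = CosineDPMSolverMultistepScheduler()
scheduler.set_timesteps(num_inference_steps=20)

# index_for_timestep maps a timestep value back to its position in scheduler.timesteps.
t = scheduler.timesteps[5]
print(scheduler.index_for_timestep(t))  # expected: 5 for a strictly decreasing schedule
```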
