
Commit 8c42542

Merge branch 'main' into flux_ptxla_trillium
2 parents: 2816222 + 83ba01a

29 files changed: +367 -776 lines

examples/advanced_diffusion_training/README.md

Lines changed: 11 additions & 0 deletions
````diff
@@ -67,6 +67,17 @@ write_basic_config()
 When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
 Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
 
+Lastly, we recommend logging into your HF account so that your trained LoRA is automatically uploaded to the hub:
+```bash
+huggingface-cli login
+```
+This command will prompt you for a token. Copy-paste yours from your [settings/tokens](https://huggingface.co/settings/tokens),and press Enter.
+
+> [!NOTE]
+> In the examples below we use `wandb` to document the training runs. To do the same, make sure to install `wandb`:
+> `pip install wandb`
+> Alternatively, you can use other tools / train without reporting by modifying the flag `--report_to="wandb"`.
+
 ### Pivotal Tuning
 **Training with text encoder(s)**
````
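For non-interactive setups (CI jobs, remote training hosts), the same login can be done programmatically with `huggingface_hub`. This is an illustrative aside rather than part of the diff above; it assumes `huggingface_hub` is installed and that the token is exposed through an `HF_TOKEN` environment variable:

```python
import os

from huggingface_hub import login

# Scriptable equivalent of `huggingface-cli login`. HF_TOKEN is an assumed
# environment variable holding the token copied from settings/tokens.
login(token=os.environ["HF_TOKEN"])
```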

examples/advanced_diffusion_training/README_flux.md

Lines changed: 11 additions & 0 deletions
````diff
@@ -65,6 +65,17 @@ write_basic_config()
 When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
 Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
 
+Lastly, we recommend logging into your HF account so that your trained LoRA is automatically uploaded to the hub:
+```bash
+huggingface-cli login
+```
+This command will prompt you for a token. Copy-paste yours from your [settings/tokens](https://huggingface.co/settings/tokens),and press Enter.
+
+> [!NOTE]
+> In the examples below we use `wandb` to document the training runs. To do the same, make sure to install `wandb`:
+> `pip install wandb`
+> Alternatively, you can use other tools / train without reporting by modifying the flag `--report_to="wandb"`.
+
 ### Target Modules
 When LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the Unet that relate the image representations with the prompts that describe them.
 More recently, SOTA text-to-image diffusion models replaced the Unet with a diffusion Transformer(DiT). With this change, we may also want to explore
````
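The context lines above lead into the README's discussion of which modules to target with LoRA in DiT-based models. As a hedged illustration (not taken from this commit), target modules are typically declared through a PEFT `LoraConfig`; the projection names below follow common diffusers attention naming and would need to match the actual transformer being trained:

```python
from peft import LoraConfig

# Sketch only: the right set of target modules is a training choice and the
# module names depend on the model architecture.
transformer_lora_config = LoraConfig(
    r=16,                      # LoRA rank
    lora_alpha=16,             # scaling factor (effective scale is alpha / r)
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)
```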

examples/community/pipeline_flux_differential_img2img.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -875,10 +875,10 @@ def __call__(
         image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor)
         mu = calculate_shift(
             image_seq_len,
-            self.scheduler.config.base_image_seq_len,
-            self.scheduler.config.max_image_seq_len,
-            self.scheduler.config.base_shift,
-            self.scheduler.config.max_shift,
+            self.scheduler.config.get("base_image_seq_len", 256),
+            self.scheduler.config.get("max_image_seq_len", 4096),
+            self.scheduler.config.get("base_shift", 0.5),
+            self.scheduler.config.get("max_shift", 1.16),
         )
         timesteps, num_inference_steps = retrieve_timesteps(
             self.scheduler,
```
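The fallback values passed to `.get()` (256, 4096, 0.5, 1.16) mirror the parameter defaults of `calculate_shift` itself, so a scheduler config that lacks these keys falls back to the standard Flux values instead of raising. For reference, a sketch of the helper, paraphrased from `diffusers.pipelines.flux.pipeline_flux` (treat the exact body as an approximation):

```python
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.16,
):
    # Linearly interpolate the timestep-shift parameter mu between base_shift
    # and max_shift as the latent sequence length grows from base to max.
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    return image_seq_len * m + b
```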

examples/community/pipeline_flux_rf_inversion.py

Lines changed: 8 additions & 8 deletions
```diff
@@ -820,10 +820,10 @@ def __call__(
         image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
         mu = calculate_shift(
             image_seq_len,
-            self.scheduler.config.base_image_seq_len,
-            self.scheduler.config.max_image_seq_len,
-            self.scheduler.config.base_shift,
-            self.scheduler.config.max_shift,
+            self.scheduler.config.get("base_image_seq_len", 256),
+            self.scheduler.config.get("max_image_seq_len", 4096),
+            self.scheduler.config.get("base_shift", 0.5),
+            self.scheduler.config.get("max_shift", 1.16),
         )
         timesteps, num_inference_steps = retrieve_timesteps(
             self.scheduler,
@@ -990,10 +990,10 @@ def invert(
         image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
         mu = calculate_shift(
             image_seq_len,
-            self.scheduler.config.base_image_seq_len,
-            self.scheduler.config.max_image_seq_len,
-            self.scheduler.config.base_shift,
-            self.scheduler.config.max_shift,
+            self.scheduler.config.get("base_image_seq_len", 256),
+            self.scheduler.config.get("max_image_seq_len", 4096),
+            self.scheduler.config.get("base_shift", 0.5),
+            self.scheduler.config.get("max_shift", 1.16),
         )
         timesteps, num_inversion_steps = retrieve_timesteps(
             self.scheduler,
```
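The same substitution is applied in both the `__call__` and `invert` paths. As a quick sanity check on the numbers (assuming the `calculate_shift` sketch shown earlier and the default fallbacks), a hypothetical 1024x1024 input with `vae_scale_factor=8` and the `// 2` packing used here yields a sequence length of 4096, which lands exactly on the `max_shift` fallback:

```python
# Hypothetical worked example, not part of the diff.
image_seq_len = (1024 // 8 // 2) * (1024 // 8 // 2)  # 64 * 64 = 4096
m = (1.16 - 0.5) / (4096 - 256)                      # slope of the linear schedule
b = 0.5 - m * 256                                    # intercept
mu = image_seq_len * m + b                           # 1.16, i.e. the max_shift value
```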

examples/community/pipeline_flux_with_cfg.py

Lines changed: 5 additions & 4 deletions
```diff
@@ -64,6 +64,7 @@
 """
 
 
+# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
 def calculate_shift(
     image_seq_len,
     base_seq_len: int = 256,
@@ -755,10 +756,10 @@ def __call__(
         image_seq_len = latents.shape[1]
         mu = calculate_shift(
             image_seq_len,
-            self.scheduler.config.base_image_seq_len,
-            self.scheduler.config.max_image_seq_len,
-            self.scheduler.config.base_shift,
-            self.scheduler.config.max_shift,
+            self.scheduler.config.get("base_image_seq_len", 256),
+            self.scheduler.config.get("max_image_seq_len", 4096),
+            self.scheduler.config.get("base_shift", 0.5),
+            self.scheduler.config.get("max_shift", 1.16),
         )
         timesteps, num_inference_steps = retrieve_timesteps(
             self.scheduler,
```
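Besides the `.get()` change, this file also gains the `# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift` marker above the helper. The apparent motivation for `.get()` is robustness when a scheduler config does not define the dynamic-shifting keys; a minimal sketch of the difference, using a plain dict as a stand-in for the scheduler config:

```python
# Stand-in for a scheduler config saved without dynamic-shifting fields.
config = {"num_train_timesteps": 1000}

base_shift = config.get("base_shift", 0.5)  # falls back to 0.5, no error
# config["base_shift"]                      # would raise KeyError instead
```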

examples/community/rerender_a_video.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -632,7 +632,7 @@ def __call__(
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                 instead.
             frames (`List[np.ndarray]` or `torch.Tensor`): The input images to be used as the starting point for the image generation process.
-            control_frames (`List[np.ndarray]` or `torch.Tensor`): The ControlNet input images condition to provide guidance to the `unet` for generation.
+            control_frames (`List[np.ndarray]` or `torch.Tensor` or `Callable`): The ControlNet input images condition to provide guidance to the `unet` for generation or any callable object to convert frame to control_frame.
             strength ('float'): SDEdit strength.
             num_inference_steps (`int`, *optional*, defaults to 50):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -789,7 +789,7 @@ def __call__(
         # Currently we only support single control
         if isinstance(controlnet, ControlNetModel):
             control_image = self.prepare_control_image(
-                image=control_frames[0],
+                image=control_frames(frames[0]) if callable(control_frames) else control_frames[0],
                 width=width,
                 height=height,
                 batch_size=batch_size,
@@ -924,7 +924,7 @@ def __call__(
         for idx in range(1, len(frames)):
             image = frames[idx]
             prev_image = frames[idx - 1]
-            control_image = control_frames[idx]
+            control_image = control_frames(image) if callable(control_frames) else control_frames[idx]
             # 5.1 prepare frames
             image = self.image_processor.preprocess(image).to(dtype=self.dtype)
             prev_image = self.image_processor.preprocess(prev_image).to(dtype=self.dtype)
```
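With this change, `control_frames` may be a callable that maps each input frame to its ControlNet conditioning image instead of a precomputed list. A hedged usage sketch follows; the Canny thresholds and channel replication are arbitrary choices, and the pipeline call is abbreviated:

```python
import cv2
import numpy as np

def canny_control(frame: np.ndarray) -> np.ndarray:
    # Turn one RGB frame into a 3-channel Canny edge map, the usual
    # conditioning input for a canny ControlNet.
    gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
    edges = cv2.Canny(gray, 100, 200)
    return np.stack([edges] * 3, axis=-1)

# The pipeline can now take the callable directly in place of a frame list:
# output = pipe(prompt=..., frames=frames, control_frames=canny_control, ...)
```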

src/diffusers/loaders/lora_base.py

Lines changed: 156 additions & 21 deletions
```diff
@@ -28,13 +28,20 @@
 from ..utils import (
     USE_PEFT_BACKEND,
     _get_model_file,
+    convert_state_dict_to_diffusers,
+    convert_state_dict_to_peft,
     delete_adapter_layers,
     deprecate,
+    get_adapter_name,
+    get_peft_kwargs,
     is_accelerate_available,
     is_peft_available,
+    is_peft_version,
     is_transformers_available,
+    is_transformers_version,
     logging,
     recurse_remove_peft_layers,
+    scale_lora_layers,
     set_adapter_layers,
     set_weights_and_activate_adapters,
 )
@@ -43,6 +50,8 @@
 if is_transformers_available():
     from transformers import PreTrainedModel
 
+    from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules
+
 if is_peft_available():
     from peft.tuners.tuners_utils import BaseTunerLayer
 
@@ -297,6 +306,152 @@ def _best_guess_weight_name(
     return weight_name
 
 
+def _load_lora_into_text_encoder(
+    state_dict,
+    network_alphas,
+    text_encoder,
+    prefix=None,
+    lora_scale=1.0,
+    text_encoder_name="text_encoder",
+    adapter_name=None,
+    _pipeline=None,
+    low_cpu_mem_usage=False,
+):
+    if not USE_PEFT_BACKEND:
+        raise ValueError("PEFT backend is required for this method.")
+
+    peft_kwargs = {}
+    if low_cpu_mem_usage:
+        if not is_peft_version(">=", "0.13.1"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+        if not is_transformers_version(">", "4.45.2"):
+            # Note from sayakpaul: It's not in `transformers` stable yet.
+            # https://github.com/huggingface/transformers/pull/33725/
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
+            )
+        peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
+
+    from peft import LoraConfig
+
+    # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
+    # then the `state_dict` keys should have `unet_name` and/or `text_encoder_name` as
+    # their prefixes.
+    keys = list(state_dict.keys())
+    prefix = text_encoder_name if prefix is None else prefix
+
+    # Safe prefix to check with.
+    if any(text_encoder_name in key for key in keys):
+        # Load the layers corresponding to text encoder and make necessary adjustments.
+        text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
+        text_encoder_lora_state_dict = {
+            k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
+        }
+
+        if len(text_encoder_lora_state_dict) > 0:
+            logger.info(f"Loading {prefix}.")
+            rank = {}
+            text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
+
+            # convert state dict
+            text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
+
+            for name, _ in text_encoder_attn_modules(text_encoder):
+                for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
+                    rank_key = f"{name}.{module}.lora_B.weight"
+                    if rank_key not in text_encoder_lora_state_dict:
+                        continue
+                    rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
+
+            for name, _ in text_encoder_mlp_modules(text_encoder):
+                for module in ("fc1", "fc2"):
+                    rank_key = f"{name}.{module}.lora_B.weight"
+                    if rank_key not in text_encoder_lora_state_dict:
+                        continue
+                    rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
+
+            if network_alphas is not None:
+                alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
+                network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
+
+            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
+
+            if "use_dora" in lora_config_kwargs:
+                if lora_config_kwargs["use_dora"]:
+                    if is_peft_version("<", "0.9.0"):
+                        raise ValueError(
+                            "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+                        )
+                else:
+                    if is_peft_version("<", "0.9.0"):
+                        lora_config_kwargs.pop("use_dora")
+
+            if "lora_bias" in lora_config_kwargs:
+                if lora_config_kwargs["lora_bias"]:
+                    if is_peft_version("<=", "0.13.2"):
+                        raise ValueError(
+                            "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
+                        )
+                else:
+                    if is_peft_version("<=", "0.13.2"):
+                        lora_config_kwargs.pop("lora_bias")
+
+            lora_config = LoraConfig(**lora_config_kwargs)
+
+            # adapter_name
+            if adapter_name is None:
+                adapter_name = get_adapter_name(text_encoder)
+
+            is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)
+
+            # inject LoRA layers and load the state dict
+            # in transformers we automatically check whether the adapter name is already in use or not
+            text_encoder.load_adapter(
+                adapter_name=adapter_name,
+                adapter_state_dict=text_encoder_lora_state_dict,
+                peft_config=lora_config,
+                **peft_kwargs,
+            )
+
+            # scale LoRA layers with `lora_scale`
+            scale_lora_layers(text_encoder, weight=lora_scale)
+
+            text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
+
+            # Offload back.
+            if is_model_cpu_offload:
+                _pipeline.enable_model_cpu_offload()
+            elif is_sequential_cpu_offload:
+                _pipeline.enable_sequential_cpu_offload()
+            # Unsafe code />
+
+
+def _func_optionally_disable_offloading(_pipeline):
+    is_model_cpu_offload = False
+    is_sequential_cpu_offload = False
+
+    if _pipeline is not None and _pipeline.hf_device_map is None:
+        for _, component in _pipeline.components.items():
+            if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
+                if not is_model_cpu_offload:
+                    is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
+                if not is_sequential_cpu_offload:
+                    is_sequential_cpu_offload = (
+                        isinstance(component._hf_hook, AlignDevicesHook)
+                        or hasattr(component._hf_hook, "hooks")
+                        and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
+                    )
+
+                logger.info(
+                    "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
+                )
+                remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
+
+    return (is_model_cpu_offload, is_sequential_cpu_offload)
+
+
 class LoraBaseMixin:
     """Utility class for handling LoRAs."""
 
@@ -327,27 +482,7 @@ def _optionally_disable_offloading(cls, _pipeline):
             tuple:
                 A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
         """
-        is_model_cpu_offload = False
-        is_sequential_cpu_offload = False
-
-        if _pipeline is not None and _pipeline.hf_device_map is None:
-            for _, component in _pipeline.components.items():
-                if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
-                    if not is_model_cpu_offload:
-                        is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
-                    if not is_sequential_cpu_offload:
-                        is_sequential_cpu_offload = (
-                            isinstance(component._hf_hook, AlignDevicesHook)
-                            or hasattr(component._hf_hook, "hooks")
-                            and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
-                        )
-
-                    logger.info(
-                        "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
-                    )
-                    remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
-
-        return (is_model_cpu_offload, is_sequential_cpu_offload)
+        return _func_optionally_disable_offloading(_pipeline=_pipeline)
 
     @classmethod
     def _fetch_state_dict(cls, *args, **kwargs):
```
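Net effect of the refactor: text-encoder LoRA loading and the Accelerate offload bookkeeping become module-level helpers that the individual loader mixins can share, with `LoraBaseMixin._optionally_disable_offloading` reduced to a thin wrapper. A hedged sketch of how a caller might use the new helper; argument names are taken from the signature added above, while the values are placeholders (the real call sites live in the per-pipeline LoRA loader classes, which are outside this diff):

```python
# Hypothetical call, for illustration only.
_load_lora_into_text_encoder(
    state_dict=state_dict,            # LoRA weights keyed with a "text_encoder." prefix
    network_alphas=network_alphas,    # optional per-layer alphas, or None
    text_encoder=pipe.text_encoder,   # a transformers text encoder with PEFT support
    prefix="text_encoder",
    lora_scale=1.0,
    adapter_name=None,                # auto-derived via get_adapter_name() when None
    _pipeline=pipe,                   # lets the helper toggle CPU-offload hooks
    low_cpu_mem_usage=False,
)
```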
