
Commit eb1a3f4

Merge branch 'bria_3_2_pipeline' of https://github.com/galbria/diffusers into bria_3_2_pipeline

2 parents: be29631 + 649767f

20 files changed (+1566 -67 lines)

docs/source/en/_toctree.yml

Lines changed: 1 addition & 1 deletion

@@ -179,7 +179,7 @@
   isExpanded: false
   sections:
   - local: quantization/overview
-    title: Getting Started
+    title: Getting started
  - local: quantization/bitsandbytes
    title: bitsandbytes
  - local: quantization/gguf

docs/source/en/api/quantization.md

Lines changed: 4 additions & 4 deletions

@@ -27,19 +27,19 @@ Learn how to quantize models in the [Quantization](../quantization/overview) gui

 ## BitsAndBytesConfig

-[[autodoc]] BitsAndBytesConfig
+[[autodoc]] quantizers.quantization_config.BitsAndBytesConfig

 ## GGUFQuantizationConfig

-[[autodoc]] GGUFQuantizationConfig
+[[autodoc]] quantizers.quantization_config.GGUFQuantizationConfig

 ## QuantoConfig

-[[autodoc]] QuantoConfig
+[[autodoc]] quantizers.quantization_config.QuantoConfig

 ## TorchAoConfig

-[[autodoc]] TorchAoConfig
+[[autodoc]] quantizers.quantization_config.TorchAoConfig

 ## DiffusersQuantizer


docs/source/en/quantization/overview.md

Lines changed: 17 additions & 11 deletions

@@ -11,27 +11,33 @@ specific language governing permissions and limitations under the License.

 -->

-# Quantization
+# Getting started

 Quantization focuses on representing data with fewer bits while also trying to preserve the precision of the original data. This often means converting a data type to represent the same information with fewer bits. For example, if your model weights are stored as 32-bit floating points and they're quantized to 16-bit floating points, this halves the model size which makes it easier to store and reduces memory usage. Lower precision can also speedup inference because it takes less time to perform calculations with fewer bits.

 Diffusers supports multiple quantization backends to make large diffusion models like [Flux](../api/pipelines/flux) more accessible. This guide shows how to use the [`~quantizers.PipelineQuantizationConfig`] class to quantize a pipeline during its initialization from a pretrained or non-quantized checkpoint.

 ## Pipeline-level quantization

-There are two ways you can use [`~quantizers.PipelineQuantizationConfig`] depending on the level of control you want over the quantization specifications of each model in the pipeline.
+There are two ways to use [`~quantizers.PipelineQuantizationConfig`] depending on how much customization you want to apply to the quantization configuration.

-- for more basic and simple use cases, you only need to define the `quant_backend`, `quant_kwargs`, and `components_to_quantize`
-- for more granular quantization control, provide a `quant_mapping` that provides the quantization specifications for the individual model components
+- for basic use cases, define the `quant_backend`, `quant_kwargs`, and `components_to_quantize` arguments
+- for granular quantization control, define a `quant_mapping` that provides the quantization configuration for individual model components

-### Simple quantization
+### Basic quantization

 Initialize [`~quantizers.PipelineQuantizationConfig`] with the following parameters.

 - `quant_backend` specifies which quantization backend to use. Currently supported backends include: `bitsandbytes_4bit`, `bitsandbytes_8bit`, `gguf`, `quanto`, and `torchao`.
-- `quant_kwargs` contains the specific quantization arguments to use.
+- `quant_kwargs` specifies the quantization arguments to use.
+
+> [!TIP]
+> These `quant_kwargs` arguments are different for each backend. Refer to the [Quantization API](../api/quantization) docs to view the arguments for each backend.
+
 - `components_to_quantize` specifies which components of the pipeline to quantize. Typically, you should quantize the most compute intensive components like the transformer. The text encoder is another component to consider quantizing if a pipeline has more than one such as [`FluxPipeline`]. The example below quantizes the T5 text encoder in [`FluxPipeline`] while keeping the CLIP model intact.

+The example below loads the bitsandbytes backend with the following arguments from [`~quantizers.quantization_config.BitsAndBytesConfig`], `load_in_4bit`, `bnb_4bit_quant_type`, and `bnb_4bit_compute_dtype`.
+
 ```py
 import torch
 from diffusers import DiffusionPipeline

@@ -56,13 +62,13 @@ pipe = DiffusionPipeline.from_pretrained(
 image = pipe("photo of a cute dog").images[0]
 ```

-### quant_mapping
+### Advanced quantization

-The `quant_mapping` argument provides more flexible options for how to quantize each individual component in a pipeline, like combining different quantization backends.
+The `quant_mapping` argument provides more options for how to quantize each individual component in a pipeline, like combining different quantization backends.

 Initialize [`~quantizers.PipelineQuantizationConfig`] and pass a `quant_mapping` to it. The `quant_mapping` allows you to specify the quantization options for each component in the pipeline such as the transformer and text encoder.

-The example below uses two quantization backends, [`~quantizers.QuantoConfig`] and [`transformers.BitsAndBytesConfig`], for the transformer and text encoder.
+The example below uses two quantization backends, [`~quantizers.quantization_config.QuantoConfig`] and [`transformers.BitsAndBytesConfig`], for the transformer and text encoder.

 ```py
 import torch

@@ -85,7 +91,7 @@ pipeline_quant_config = PipelineQuantizationConfig(
 There is a separate bitsandbytes backend in [Transformers](https://huggingface.co/docs/transformers/main_classes/quantization#transformers.BitsAndBytesConfig). You need to import and use [`transformers.BitsAndBytesConfig`] for components that come from Transformers. For example, `text_encoder_2` in [`FluxPipeline`] is a [`~transformers.T5EncoderModel`] from Transformers so you need to use [`transformers.BitsAndBytesConfig`] instead of [`diffusers.BitsAndBytesConfig`].

 > [!TIP]
-> Use the [simple quantization](#simple-quantization) method above if you don't want to manage these distinct imports or aren't sure where each pipeline component comes from.
+> Use the [basic quantization](#basic-quantization) method above if you don't want to manage these distinct imports or aren't sure where each pipeline component comes from.

 ```py
 import torch

@@ -129,4 +135,4 @@ Check out the resources below to learn more about quantization.

 - The Transformers quantization [Overview](https://huggingface.co/docs/transformers/quantization/overview#when-to-use-what) provides an overview of the pros and cons of different quantization backends.

-- Read the [Exploring Quantization Backends in Diffusers](https://huggingface.co/blog/diffusers-quantization) blog post for a brief introduction to each quantization backend, how to choose a backend, and combining quantization with other memory optimizations.
+- Read the [Exploring Quantization Backends in Diffusers](https://huggingface.co/blog/diffusers-quantization) blog post for a brief introduction to each quantization backend, how to choose a backend, and combining quantization with other memory optimizations.
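
The code blocks above are cut off by the diff hunks, so here is a minimal sketch of the basic quantization flow this file describes, using the `quant_backend`, `quant_kwargs`, and `components_to_quantize` arguments. The model ID and the import path for `PipelineQuantizationConfig` are assumptions based on the surrounding Flux examples, not part of this commit.

```py
# Hedged sketch of the "Basic quantization" path described above.
import torch
from diffusers import DiffusionPipeline
from diffusers.quantizers import PipelineQuantizationConfig  # import path assumed

pipeline_quant_config = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_4bit",
    quant_kwargs={
        "load_in_4bit": True,
        "bnb_4bit_quant_type": "nf4",
        "bnb_4bit_compute_dtype": torch.bfloat16,
    },
    # quantize the transformer and the T5 text encoder, keep CLIP intact
    components_to_quantize=["transformer", "text_encoder_2"],
)

pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # assumed model ID
    quantization_config=pipeline_quant_config,
    torch_dtype=torch.bfloat16,
).to("cuda")

image = pipe("photo of a cute dog").images[0]
```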
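Similarly, a sketch of the `quant_mapping` path from the "Advanced quantization" section, combining the Diffusers Quanto backend with `transformers.BitsAndBytesConfig`. The dtypes and backend-specific keyword names are illustrative; check the Quantization API docs for the exact arguments.

```py
# Hedged sketch of the quant_mapping path; parameter values are illustrative.
import torch
from diffusers import DiffusionPipeline, QuantoConfig
from diffusers.quantizers import PipelineQuantizationConfig  # import path assumed
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig

pipeline_quant_config = PipelineQuantizationConfig(
    quant_mapping={
        # Diffusers backend for the transformer
        "transformer": QuantoConfig(weights_dtype="int8"),
        # Transformers backend for the T5 text encoder
        "text_encoder_2": TransformersBitsAndBytesConfig(
            load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16
        ),
    }
)

pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # assumed model ID
    quantization_config=pipeline_quant_config,
    torch_dtype=torch.bfloat16,
).to("cuda")
```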

docs/source/en/tutorials/using_peft_for_inference.md

Lines changed: 16 additions & 1 deletion

@@ -319,6 +319,19 @@ If you expect to varied resolutions during inference with this feature, then mak

 There are still scenarios where recompilation is unavoidable, such as when the hotswapped LoRA targets more layers than the initial adapter. Try to load the LoRA that targets the most layers *first*. For more details about this limitation, refer to the PEFT [hotswapping](https://huggingface.co/docs/peft/main/en/package_reference/hotswap#peft.utils.hotswap.hotswap_adapter) docs.

+<details>
+<summary>Technical details of hotswapping</summary>
+
+The [`~loaders.lora_base.LoraBaseMixin.enable_lora_hotswap`] method converts the LoRA scaling factor from floats to torch.tensors and pads the shape of the weights to the largest required shape to avoid reassigning the whole attribute when the data in the weights are replaced.
+
+This is why the `max_rank` argument is important. The results are unchanged even when the values are padded with zeros. Computation may be slower though depending on the padding size.
+
+Since no new LoRA attributes are added, each subsequent LoRA is only allowed to target the same layers, or subset of layers, the first LoRA targets. Choosing the LoRA loading order is important because if the LoRAs target disjoint layers, you may end up creating a dummy LoRA that targets the union of all target layers.
+
+For more implementation details, take a look at the [`hotswap.py`](https://github.com/huggingface/peft/blob/92d65cafa51c829484ad3d95cf71d09de57ff066/src/peft/utils/hotswap.py) file.
+
+</details>
+
 ## Merge

 The weights from each LoRA can be merged together to produce a blend of multiple existing styles. There are several methods for merging LoRAs, each of which differ in *how* the weights are merged (may affect generation quality).

@@ -673,4 +686,6 @@ Browse the [LoRA Studio](https://lorastudio.co/models) for different LoRAs to us
   height="450"
 ></iframe>

-You can find additional LoRAs in the [FLUX LoRA the Explorer](https://huggingface.co/spaces/multimodalart/flux-lora-the-explorer) and [LoRA the Explorer](https://huggingface.co/spaces/multimodalart/LoraTheExplorer) Spaces.
+You can find additional LoRAs in the [FLUX LoRA the Explorer](https://huggingface.co/spaces/multimodalart/flux-lora-the-explorer) and [LoRA the Explorer](https://huggingface.co/spaces/multimodalart/LoraTheExplorer) Spaces.
+
+Check out the [Fast LoRA inference for Flux with Diffusers and PEFT](https://huggingface.co/blog/lora-fast) blog post to learn how to optimize LoRA inference with methods like FlashAttention-3 and fp8 quantization.
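
A short sketch of the hotswapping workflow the collapsible section above refers to may help when reviewing; the LoRA repository names are placeholders, and the exact keyword names (`target_rank`, `hotswap`, `adapter_name`) should be verified against the guide this diff extends.

```py
# Hedged sketch of LoRA hotswapping; repo IDs are placeholders.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Reserve space for the largest rank you plan to swap in, so the padded
# weights never need to be reallocated (and recompiled) later.
max_rank = 128
pipe.enable_lora_hotswap(target_rank=max_rank)

# Load the LoRA that targets the most layers first.
pipe.load_lora_weights("<first-lora-repo>", adapter_name="default")
pipe.transformer = torch.compile(pipe.transformer)  # optional

# Swap in another LoRA in place, without reloading or recompiling.
pipe.load_lora_weights("<second-lora-repo>", adapter_name="default", hotswap=True)
```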

src/diffusers/__init__.py

Lines changed: 4 additions & 0 deletions

@@ -365,6 +365,8 @@
 else:
     _import_structure["modular_pipelines"].extend(
         [
+            "FluxAutoBlocks",
+            "FluxModularPipeline",
             "StableDiffusionXLAutoBlocks",
             "StableDiffusionXLModularPipeline",
             "WanAutoBlocks",

@@ -1002,6 +1004,8 @@
         from .utils.dummy_torch_and_transformers_objects import *  # noqa F403
     else:
         from .modular_pipelines import (
+            FluxAutoBlocks,
+            FluxModularPipeline,
            StableDiffusionXLAutoBlocks,
            StableDiffusionXLModularPipeline,
            WanAutoBlocks,

src/diffusers/hooks/_helpers.py

Lines changed: 8 additions & 0 deletions

@@ -107,6 +107,7 @@ def _register(cls):
 def _register_attention_processors_metadata():
     from ..models.attention_processor import AttnProcessor2_0
     from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor
+    from ..models.transformers.transformer_flux import FluxAttnProcessor
     from ..models.transformers.transformer_wan import WanAttnProcessor2_0

     # AttnProcessor2_0

@@ -132,6 +133,11 @@ def _register_attention_processors_metadata():
             skip_processor_output_fn=_skip_proc_output_fn_Attention_WanAttnProcessor2_0,
         ),
     )
+    # FluxAttnProcessor
+    AttentionProcessorRegistry.register(
+        model_class=FluxAttnProcessor,
+        metadata=AttentionProcessorMetadata(skip_processor_output_fn=_skip_proc_output_fn_Attention_FluxAttnProcessor),
+    )


 def _register_transformer_blocks_metadata():

@@ -271,4 +277,6 @@ def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, *
 _skip_proc_output_fn_Attention_AttnProcessor2_0 = _skip_attention___ret___hidden_states
 _skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states
 _skip_proc_output_fn_Attention_WanAttnProcessor2_0 = _skip_attention___ret___hidden_states
+# not sure what this is yet.
+_skip_proc_output_fn_Attention_FluxAttnProcessor = _skip_attention___ret___hidden_states
 # fmt: on

src/diffusers/models/transformers/transformer_ltx.py

Lines changed: 103 additions & 23 deletions

@@ -1,4 +1,4 @@
-# Copyright 2025 The Genmo team and The HuggingFace Team.
+# Copyright 2025 The Lightricks team and The HuggingFace Team.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");

@@ -13,19 +13,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import inspect
 import math
 from typing import Any, Dict, Optional, Tuple, Union

 import torch
 import torch.nn as nn
-import torch.nn.functional as F

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, deprecate, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
-from ..attention import FeedForward
-from ..attention_processor import Attention
+from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
+from ..attention_dispatch import dispatch_attention_fn
 from ..cache_utils import CacheMixin
 from ..embeddings import PixArtAlphaTextProjection
 from ..modeling_outputs import Transformer2DModelOutput

@@ -37,20 +37,30 @@


 class LTXVideoAttentionProcessor2_0:
+    def __new__(cls, *args, **kwargs):
+        deprecation_message = "`LTXVideoAttentionProcessor2_0` is deprecated and this will be removed in a future version. Please use `LTXVideoAttnProcessor`"
+        deprecate("LTXVideoAttentionProcessor2_0", "1.0.0", deprecation_message)
+
+        return LTXVideoAttnProcessor(*args, **kwargs)
+
+
+class LTXVideoAttnProcessor:
     r"""
-    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
-    used in the LTX model. It applies a normalization layer and rotary embedding on the query and key vector.
+    Processor for implementing attention (SDPA is used by default if you're using PyTorch 2.0). This is used in the LTX
+    model. It applies a normalization layer and rotary embedding on the query and key vector.
     """

+    _attention_backend = None
+
     def __init__(self):
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError(
-                "LTXVideoAttentionProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+        if is_torch_version("<", "2.0"):
+            raise ValueError(
+                "LTX attention processors require a minimum PyTorch version of 2.0. Please upgrade your PyTorch installation."
             )

     def __call__(
         self,
-        attn: Attention,
+        attn: "LTXAttention",
         hidden_states: torch.Tensor,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,

@@ -78,21 +88,91 @@ def __call__(
             query = apply_rotary_emb(query, image_rotary_emb)
             key = apply_rotary_emb(key, image_rotary_emb)

-        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        query = query.unflatten(2, (attn.heads, -1))
+        key = key.unflatten(2, (attn.heads, -1))
+        value = value.unflatten(2, (attn.heads, -1))
+
+        hidden_states = dispatch_attention_fn(
+            query,
+            key,
+            value,
+            attn_mask=attention_mask,
+            dropout_p=0.0,
+            is_causal=False,
+            backend=self._attention_backend,
         )
-        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+        hidden_states = hidden_states.flatten(2, 3)
         hidden_states = hidden_states.to(query.dtype)

         hidden_states = attn.to_out[0](hidden_states)
         hidden_states = attn.to_out[1](hidden_states)
         return hidden_states


+class LTXAttention(torch.nn.Module, AttentionModuleMixin):
+    _default_processor_cls = LTXVideoAttnProcessor
+    _available_processors = [LTXVideoAttnProcessor]
+
+    def __init__(
+        self,
+        query_dim: int,
+        heads: int = 8,
+        kv_heads: int = 8,
+        dim_head: int = 64,
+        dropout: float = 0.0,
+        bias: bool = True,
+        cross_attention_dim: Optional[int] = None,
+        out_bias: bool = True,
+        qk_norm: str = "rms_norm_across_heads",
+        processor=None,
+    ):
+        super().__init__()
+        if qk_norm != "rms_norm_across_heads":
+            raise NotImplementedError("Only 'rms_norm_across_heads' is supported as a valid value for `qk_norm`.")
+
+        self.head_dim = dim_head
+        self.inner_dim = dim_head * heads
+        self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
+        self.query_dim = query_dim
+        self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
+        self.use_bias = bias
+        self.dropout = dropout
+        self.out_dim = query_dim
+        self.heads = heads
+
+        norm_eps = 1e-5
+        norm_elementwise_affine = True
+        self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
+        self.norm_k = torch.nn.RMSNorm(dim_head * kv_heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
+        self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
+        self.to_k = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
+        self.to_v = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
+        self.to_out = torch.nn.ModuleList([])
+        self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
+        self.to_out.append(torch.nn.Dropout(dropout))
+
+        if processor is None:
+            processor = self._default_processor_cls()
+        self.set_processor(processor)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
+        unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters]
+        if len(unused_kwargs) > 0:
+            logger.warning(
+                f"attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
+            )
+        kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
+        return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
+
+
 class LTXVideoRotaryPosEmbed(nn.Module):
     def __init__(
         self,

@@ -231,7 +311,7 @@ def __init__(
         super().__init__()

         self.norm1 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
-        self.attn1 = Attention(
+        self.attn1 = LTXAttention(
            query_dim=dim,
            heads=num_attention_heads,
            kv_heads=num_attention_heads,

@@ -240,11 +320,10 @@ def __init__(
            cross_attention_dim=None,
            out_bias=attention_out_bias,
            qk_norm=qk_norm,
-            processor=LTXVideoAttentionProcessor2_0(),
        )

        self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
-        self.attn2 = Attention(
+        self.attn2 = LTXAttention(
            query_dim=dim,
            cross_attention_dim=cross_attention_dim,
            heads=num_attention_heads,

@@ -253,7 +332,6 @@ def __init__(
            bias=attention_bias,
            out_bias=attention_out_bias,
            qk_norm=qk_norm,
-            processor=LTXVideoAttentionProcessor2_0(),
        )

        self.ff = FeedForward(dim, activation_fn=activation_fn)

@@ -299,7 +377,9 @@ def forward(


 @maybe_allow_in_graph
-class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin, CacheMixin):
+class LTXVideoTransformer3DModel(
+    ModelMixin, ConfigMixin, AttentionMixin, FromOriginalModelMixin, PeftAdapterMixin, CacheMixin
+):
     r"""
     A Transformer model for video-like data used in [LTX](https://huggingface.co/Lightricks/LTX-Video).
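
For reviewers, a hedged sketch of how the new `LTXAttention` module added in this file can be exercised in isolation; the dimensions are illustrative and not taken from an actual LTX-Video config, and a recent PyTorch is assumed (the module uses `torch.nn.RMSNorm`).

```py
# Hedged sketch exercising the new LTXAttention + LTXVideoAttnProcessor path.
# Dimensions are illustrative.
import torch
from diffusers.models.transformers.transformer_ltx import LTXAttention

# Self-attention configuration: inner_dim = heads * dim_head = query_dim
attn = LTXAttention(query_dim=512, heads=8, kv_heads=8, dim_head=64)

hidden_states = torch.randn(1, 128, 512)  # (batch, tokens, query_dim)
# image_rotary_emb is optional; the processor projects q/k/v, applies the
# q/k RMSNorm, and dispatches attention via dispatch_attention_fn (SDPA by default).
out = attn(hidden_states)
print(out.shape)  # torch.Size([1, 128, 512])
```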
