Commit 9acadf3

Merge branch 'to-single-file/flux' into to-single-file/wan
2 parents: da9bfb1 + bc64f12

File tree: 9 files changed, +75 −252 lines changed

9 files changed

+75
-252
lines changed

examples/controlnet/train_controlnet_sd3.py

Lines changed: 1 addition & 1 deletion
@@ -1330,7 +1330,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
 # controlnet(s) inference
 controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)
 controlnet_image = vae.encode(controlnet_image).latent_dist.sample()
-controlnet_image = controlnet_image * vae.config.scaling_factor
+controlnet_image = (controlnet_image - vae.config.shift_factor) * vae.config.scaling_factor

 control_block_res_samples = controlnet(
     hidden_states=noisy_model_input,
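
Note: the change above swaps plain scaling of the VAE latents for the shift-then-scale normalization that SD3's VAE config expects (it defines both scaling_factor and shift_factor). Below is a minimal sketch, not part of the commit, of the encode/decode round-trip under that convention, assuming vae is a diffusers AutoencoderKL exposing both config fields:

import torch

# Sketch (not from the commit): encode/decode round-trip under the shift-then-scale
# convention used above. `vae` is assumed to be an AutoencoderKL whose config defines
# both `scaling_factor` and `shift_factor`, as the SD3 VAE does.
def encode_to_latents(vae, pixel_values: torch.Tensor) -> torch.Tensor:
    latents = vae.encode(pixel_values).latent_dist.sample()
    return (latents - vae.config.shift_factor) * vae.config.scaling_factor


def decode_from_latents(vae, latents: torch.Tensor) -> torch.Tensor:
    # Invert the normalization (un-scale, then un-shift) before decoding.
    latents = latents / vae.config.scaling_factor + vae.config.shift_factor
    return vae.decode(latents).sample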

examples/server/requirements.txt

Lines changed: 9 additions & 6 deletions
@@ -1,10 +1,10 @@
 # This file was autogenerated by uv via the following command:
 #    uv pip compile requirements.in -o requirements.txt
-aiohappyeyeballs==2.4.3
+aiohappyeyeballs==2.6.1
     # via aiohttp
-aiohttp==3.10.10
+aiohttp==3.12.14
     # via -r requirements.in
-aiosignal==1.3.1
+aiosignal==1.4.0
     # via aiohttp
 annotated-types==0.7.0
     # via pydantic
@@ -29,7 +29,6 @@ filelock==3.16.1
     #   huggingface-hub
     #   torch
     #   transformers
-    #   triton
 frozenlist==1.5.0
     # via
     #   aiohttp
@@ -111,7 +110,9 @@ prometheus-client==0.21.0
 prometheus-fastapi-instrumentator==7.0.0
     # via -r requirements.in
 propcache==0.2.0
-    # via yarl
+    # via
+    #   aiohttp
+    #   yarl
 py-consul==1.5.3
     # via -r requirements.in
 pydantic==2.9.2
@@ -155,7 +156,9 @@ triton==3.3.0
     # via torch
 typing-extensions==4.12.2
     # via
+    #   aiosignal
     #   anyio
+    #   exceptiongroup
     #   fastapi
     #   huggingface-hub
     #   multidict
@@ -168,5 +171,5 @@ urllib3==2.5.0
     # via requests
 uvicorn==0.32.0
     # via -r requirements.in
-yarl==1.16.0
+yarl==1.18.3
     # via aiohttp

src/diffusers/hooks/faster_cache.py

Lines changed: 2 additions & 1 deletion
@@ -18,6 +18,7 @@

 import torch

+from ..models.attention import AttentionModuleMixin
 from ..models.attention_processor import Attention, MochiAttention
 from ..models.modeling_outputs import Transformer2DModelOutput
 from ..utils import logging
@@ -567,7 +568,7 @@ def high_frequency_weight_callback(module: torch.nn.Module) -> float:
     _apply_faster_cache_on_denoiser(module, config)

     for name, submodule in module.named_modules():
-        if not isinstance(submodule, _ATTENTION_CLASSES):
+        if not isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
             continue
         if any(re.search(identifier, name) is not None for identifier in _TRANSFORMER_BLOCK_IDENTIFIERS):
             _apply_faster_cache_on_attention_class(name, submodule, config)
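
Note: the hunk above widens the attention-layer filter so that modules built on the newer AttentionModuleMixin are picked up alongside the legacy Attention/MochiAttention classes. A minimal sketch of the same walk-and-filter pattern, with a placeholder class tuple and identifier list standing in for the library's own constants:

import re
import torch

# Sketch (not from the commit): walk the module tree and act only on attention layers
# that live inside transformer blocks. The class tuple and identifiers below are
# illustrative placeholders, not diffusers' own constants.
ATTENTION_CLASSES = (torch.nn.MultiheadAttention,)
TRANSFORMER_BLOCK_IDENTIFIERS = (r"transformer_blocks", r"single_transformer_blocks")


def apply_to_attention_layers(model: torch.nn.Module, apply_fn) -> None:
    for name, submodule in model.named_modules():
        if not isinstance(submodule, ATTENTION_CLASSES):
            continue
        if any(re.search(identifier, name) is not None for identifier in TRANSFORMER_BLOCK_IDENTIFIERS):
            apply_fn(name, submodule)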

src/diffusers/loaders/ip_adapter.py

Lines changed: 5 additions & 4 deletions
@@ -40,8 +40,6 @@
 from ..models.attention_processor import (
     AttnProcessor,
     AttnProcessor2_0,
-    FluxAttnProcessor2_0,
-    FluxIPAdapterJointAttnProcessor2_0,
     IPAdapterAttnProcessor,
     IPAdapterAttnProcessor2_0,
     IPAdapterXFormersAttnProcessor,
@@ -867,6 +865,9 @@ def unload_ip_adapter(self):
         >>> ...
         ```
         """
+        # TODO: once the 1.0.0 deprecations are in, we can move the imports to top-level
+        from ..models.transformers.transformer_flux import FluxAttnProcessor, FluxIPAdapterAttnProcessor
+
         # remove CLIP image encoder
         if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
             self.image_encoder = None
@@ -886,9 +887,9 @@ def unload_ip_adapter(self):
         # restore original Transformer attention processors layers
         attn_procs = {}
         for name, value in self.transformer.attn_processors.items():
-            attn_processor_class = FluxAttnProcessor2_0()
+            attn_processor_class = FluxAttnProcessor()
             attn_procs[name] = (
-                attn_processor_class if isinstance(value, (FluxIPAdapterJointAttnProcessor2_0)) else value.__class__()
+                attn_processor_class if isinstance(value, FluxIPAdapterAttnProcessor) else value.__class__()
             )
         self.transformer.set_attn_processor(attn_procs)
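
Note: the last hunk above restores per-layer attention processors once the IP-Adapter is unloaded. A minimal sketch of that restore pattern, assuming only the attn_processors mapping and set_attn_processor call used above; the two class parameters are hypothetical stand-ins for FluxIPAdapterAttnProcessor and FluxAttnProcessor:

# Sketch (not from the commit): restore default attention processors after unloading an
# adapter. `transformer` is assumed to expose the `attn_processors` mapping and
# `set_attn_processor` used in the hunk above.
def restore_default_processors(transformer, adapter_processor_cls, default_processor_cls):
    attn_procs = {}
    for name, value in transformer.attn_processors.items():
        # Swap adapter processors for a fresh default; re-instantiate everything else.
        attn_procs[name] = default_processor_cls() if isinstance(value, adapter_processor_cls) else value.__class__()
    transformer.set_attn_processor(attn_procs)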

src/diffusers/loaders/transformer_flux.py

Lines changed: 2 additions & 4 deletions
@@ -87,9 +87,7 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us
         return image_projection

     def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
-        from ..models.attention_processor import (
-            FluxIPAdapterJointAttnProcessor2_0,
-        )
+        from ..models.transformers.transformer_flux import FluxIPAdapterAttnProcessor

         if low_cpu_mem_usage:
             if is_accelerate_available():
@@ -121,7 +119,7 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_
             else:
                 cross_attention_dim = self.config.joint_attention_dim
             hidden_size = self.inner_dim
-            attn_processor_class = FluxIPAdapterJointAttnProcessor2_0
+            attn_processor_class = FluxIPAdapterAttnProcessor
             num_image_text_embeds = []
             for state_dict in state_dicts:
                 if "proj.weight" in state_dict["image_proj"]:

src/diffusers/models/attention_processor.py

Lines changed: 10 additions & 146 deletions
@@ -2501,152 +2501,6 @@ def __call__(
         return hidden_states


-class FluxIPAdapterJointAttnProcessor2_0(torch.nn.Module):
-    """Flux Attention processor for IP-Adapter."""
-
-    def __init__(
-        self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None
-    ):
-        super().__init__()
-
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError(
-                f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
-            )
-
-        self.hidden_size = hidden_size
-        self.cross_attention_dim = cross_attention_dim
-
-        if not isinstance(num_tokens, (tuple, list)):
-            num_tokens = [num_tokens]
-
-        if not isinstance(scale, list):
-            scale = [scale] * len(num_tokens)
-        if len(scale) != len(num_tokens):
-            raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
-        self.scale = scale
-
-        self.to_k_ip = nn.ModuleList(
-            [
-                nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
-                for _ in range(len(num_tokens))
-            ]
-        )
-        self.to_v_ip = nn.ModuleList(
-            [
-                nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
-                for _ in range(len(num_tokens))
-            ]
-        )
-
-    def __call__(
-        self,
-        attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: torch.FloatTensor = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        image_rotary_emb: Optional[torch.Tensor] = None,
-        ip_hidden_states: Optional[List[torch.Tensor]] = None,
-        ip_adapter_masks: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
-        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-
-        # `sample` projections.
-        hidden_states_query_proj = attn.to_q(hidden_states)
-        key = attn.to_k(hidden_states)
-        value = attn.to_v(hidden_states)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        hidden_states_query_proj = hidden_states_query_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        if attn.norm_q is not None:
-            hidden_states_query_proj = attn.norm_q(hidden_states_query_proj)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
-
-        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
-        if encoder_hidden_states is not None:
-            # `context` projections.
-            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
-            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
-            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
-
-            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-
-            if attn.norm_added_q is not None:
-                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
-            if attn.norm_added_k is not None:
-                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
-
-            # attention
-            query = torch.cat([encoder_hidden_states_query_proj, hidden_states_query_proj], dim=2)
-            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
-            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
-
-        if image_rotary_emb is not None:
-            from .embeddings import apply_rotary_emb
-
-            query = apply_rotary_emb(query, image_rotary_emb)
-            key = apply_rotary_emb(key, image_rotary_emb)
-
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
-
-        if encoder_hidden_states is not None:
-            encoder_hidden_states, hidden_states = (
-                hidden_states[:, : encoder_hidden_states.shape[1]],
-                hidden_states[:, encoder_hidden_states.shape[1] :],
-            )
-
-            # linear proj
-            hidden_states = attn.to_out[0](hidden_states)
-            # dropout
-            hidden_states = attn.to_out[1](hidden_states)
-            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
-            # IP-adapter
-            ip_query = hidden_states_query_proj
-            ip_attn_output = torch.zeros_like(hidden_states)
-
-            for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
-                ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
-            ):
-                ip_key = to_k_ip(current_ip_hidden_states)
-                ip_value = to_v_ip(current_ip_hidden_states)
-
-                ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-                ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-                # the output of sdp = (batch, num_heads, seq_len, head_dim)
-                # TODO: add support for attn.scale when we move to Torch 2.1
-                current_ip_hidden_states = F.scaled_dot_product_attention(
-                    ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
-                )
-                current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
-                    batch_size, -1, attn.heads * head_dim
-                )
-                current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
-                ip_attn_output += scale * current_ip_hidden_states
-
-            return hidden_states, encoder_hidden_states, ip_attn_output
-        else:
-            return hidden_states
-
-
 class CogVideoXAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
@@ -6019,6 +5873,16 @@ def __new__(cls, *args, **kwargs):
         return FluxAttnProcessor(*args, **kwargs)


+class FluxIPAdapterJointAttnProcessor2_0:
+    def __new__(cls, *args, **kwargs):
+        deprecation_message = "`FluxIPAdapterJointAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `FluxIPAdapterAttnProcessor`"
+        deprecate("FluxIPAdapterJointAttnProcessor2_0", "1.0.0", deprecation_message)
+
+        from .transformers.transformer_flux import FluxIPAdapterAttnProcessor
+
+        return FluxIPAdapterAttnProcessor(*args, **kwargs)
+
+
 ADDED_KV_ATTENTION_PROCESSORS = (
     AttnAddedKVProcessor,
     SlicedAttnAddedKVProcessor,
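
Note: the added FluxIPAdapterJointAttnProcessor2_0 shim keeps the old name importable while redirecting construction to the new class. A minimal, self-contained sketch of that __new__-based alias pattern, with warnings.warn standing in for diffusers' internal deprecate helper and hypothetical class names:

import warnings


class NewProcessor:
    """Hypothetical replacement class (think FluxIPAdapterAttnProcessor)."""

    def __init__(self, *args, **kwargs):
        self.args, self.kwargs = args, kwargs


class OldProcessor:
    """Deprecated alias: constructing it warns and hands back the new class."""

    def __new__(cls, *args, **kwargs):
        warnings.warn(
            "`OldProcessor` is deprecated and will be removed; use `NewProcessor` instead.",
            FutureWarning,
        )
        return NewProcessor(*args, **kwargs)


proc = OldProcessor()  # emits the FutureWarning
assert isinstance(proc, NewProcessor)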

src/diffusers/models/transformers/transformer_flux.py

Lines changed: 12 additions & 2 deletions
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-
+import inspect
 from typing import Any, Dict, List, Optional, Tuple, Union

 import numpy as np
@@ -241,7 +241,9 @@ def __call__(
             query = apply_rotary_emb(query, image_rotary_emb)
             key = apply_rotary_emb(key, image_rotary_emb)

-        hidden_states = torch.nn.functional(query, key, value, dropout_p=0.0, is_causal=False)
+        hidden_states = torch.nn.functional.scaled_dot_product_attention(
+            query, key, value, dropout_p=0.0, is_causal=False
+        )
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)

@@ -354,6 +356,14 @@ def forward(
         image_rotary_emb: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
+        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
+        quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"}
+        unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters]
+        if len(unused_kwargs) > 0:
+            logger.warning(
+                f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
+            )
+        kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
         return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
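
Note: the forward() hunk above inspects the processor's __call__ signature and drops any joint_attention_kwargs the processor does not accept, logging a warning for unexpected keys. A minimal standalone sketch of that filtering step, assuming only the standard library:

import inspect
import logging

logger = logging.getLogger(__name__)


# Sketch (not from the commit): keep only the keyword arguments the processor's __call__
# actually declares, warning about anything unexpected (mirroring the hunk above).
def filter_processor_kwargs(processor, kwargs, quiet=("ip_adapter_masks", "ip_hidden_states")):
    accepted = set(inspect.signature(processor.__call__).parameters.keys())
    unused = [k for k in kwargs if k not in accepted and k not in quiet]
    if unused:
        logger.warning(
            f"joint_attention_kwargs {unused} are not expected by {processor.__class__.__name__} and will be ignored."
        )
    return {k: v for k, v in kwargs.items() if k in accepted}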

tests/pipelines/chroma/test_pipeline_chroma.py

Lines changed: 4 additions & 11 deletions
@@ -7,12 +7,7 @@
 from diffusers import AutoencoderKL, ChromaPipeline, ChromaTransformer2DModel, FlowMatchEulerDiscreteScheduler
 from diffusers.utils.testing_utils import torch_device

-from ..test_pipelines_common import (
-    FluxIPAdapterTesterMixin,
-    PipelineTesterMixin,
-    check_qkv_fusion_matches_attn_procs_length,
-    check_qkv_fusion_processors_exist,
-)
+from ..test_pipelines_common import FluxIPAdapterTesterMixin, PipelineTesterMixin, check_qkv_fused_layers_exist


 class ChromaPipelineFastTests(
@@ -126,12 +121,10 @@ def test_fused_qkv_projections(self):
         # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
         # to the pipeline level.
         pipe.transformer.fuse_qkv_projections()
-        assert check_qkv_fusion_processors_exist(pipe.transformer), (
-            "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+        self.assertTrue(
+            check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
+            ("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
         )
-        assert check_qkv_fusion_matches_attn_procs_length(
-            pipe.transformer, pipe.transformer.original_attn_processors
-        ), "Something wrong with the attention processors concerning the fused QKV projections."

         inputs = self.get_dummy_inputs(device)
         image = pipe(**inputs).images
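
Note: the test now checks for fused to_qkv linear layers directly instead of inspecting attention processors. A minimal sketch of a helper in the spirit of check_qkv_fused_layers_exist; the attribute layout (a to_qkv torch.nn.Linear on each fused attention module) is an assumption, not the library's documented API:

import torch


# Sketch (not from the commit): returns True only if at least one module exposes the
# named fused projection and every such attribute is a torch.nn.Linear.
def qkv_fused_layers_exist(model: torch.nn.Module, layer_names=("to_qkv",)) -> bool:
    found_any = False
    for module in model.modules():
        for layer_name in layer_names:
            layer = getattr(module, layer_name, None)
            if layer is None:
                continue
            if not isinstance(layer, torch.nn.Linear):
                return False
            found_any = True
    return found_any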
