
Commit 941911c

update
1 parent c87575d commit 941911c

7 files changed: +244 / -62 lines

src/diffusers/models/attention.py

Lines changed: 2 additions & 0 deletions
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+import inspect
 from typing import Callable, Dict, Optional, Tuple, Union

 import torch

src/diffusers/models/embeddings.py

Lines changed: 11 additions & 21 deletions
@@ -1244,31 +1244,21 @@ class FluxPosEmbed(nn.Module):
     # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
     def __init__(self, theta: int, axes_dim: List[int]):
         super().__init__()
+
+        from .transformers.transformer_flux import FluxPosEmbed as FluxPosEmbed_
+
+        deprecate(
+            "FluxPosEmbed",
+            "1.0.0",
+            "Importing and using `FluxPosEmbed` from `diffusers.models.embeddings` is deprecated. Please use `FluxPosEmbed` from `diffusers.models.transformers.transformer_flux` instead.",
+        )
+
         self.theta = theta
         self.axes_dim = axes_dim
+        self._rope = FluxPosEmbed_(theta, axes_dim)

     def forward(self, ids: torch.Tensor) -> torch.Tensor:
-        n_axes = ids.shape[-1]
-        cos_out = []
-        sin_out = []
-        pos = ids.float()
-        is_mps = ids.device.type == "mps"
-        is_npu = ids.device.type == "npu"
-        freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
-        for i in range(n_axes):
-            cos, sin = get_1d_rotary_pos_embed(
-                self.axes_dim[i],
-                pos[:, i],
-                theta=self.theta,
-                repeat_interleave_real=True,
-                use_real=True,
-                freqs_dtype=freqs_dtype,
-            )
-            cos_out.append(cos)
-            sin_out.append(sin)
-        freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
-        freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
-        return freqs_cos, freqs_sin
+        return self._rope(ids)


 class TimestepEmbedding(nn.Module):
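The hunk above turns the old `FluxPosEmbed` into a thin deprecation shim that delegates to the copy now living in `transformer_flux`. A migration sketch, not part of this commit; the `theta`/`axes_dim` values and the `ids` shape are assumptions chosen to mirror typical Flux settings:

import torch

# Preferred import after this change:
from diffusers.models.transformers.transformer_flux import FluxPosEmbed

# Old path still works, but constructing the class now emits the deprecation
# warning shown in the diff and forwards all work to the new implementation:
from diffusers.models.embeddings import FluxPosEmbed as LegacyFluxPosEmbed

rope = FluxPosEmbed(theta=10000, axes_dim=[16, 56, 56])
ids = torch.zeros(4096, 3)  # (sequence_length, n_axes); contents are illustrative
freqs_cos, freqs_sin = rope(ids)

legacy_rope = LegacyFluxPosEmbed(theta=10000, axes_dim=[16, 56, 56])  # warns, same outputs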

src/diffusers/models/modeling_utils.py

Lines changed: 44 additions & 0 deletions
@@ -599,6 +599,50 @@ def enable_group_offload(
             low_cpu_mem_usage=low_cpu_mem_usage,
         )

+    def set_attention_backend(self, backend: str) -> None:
+        """
+        Set the attention backend for the model.
+
+        Args:
+            backend (`str`):
+                The name of the backend to set. Must be one of the available backends defined in
+                `AttentionBackendName`. Available backends can be found in
+                `diffusers.attention_dispatch.AttentionBackendName`. Defaults to torch native scaled dot product
+                attention as backend.
+        """
+        from .attention import AttentionModuleMixin
+        from .attention_dispatch import AttentionBackendName
+
+        backend = backend.lower()
+        available_backends = {x.value for x in AttentionBackendName.__members__.values()}
+        if backend not in available_backends:
+            raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))
+        backend = AttentionBackendName(backend)
+
+        for module in self.modules():
+            if not isinstance(module, AttentionModuleMixin):
+                continue
+            processor = module.processor
+            if processor is None or not hasattr(processor, "_attention_backend"):
+                continue
+            processor._attention_backend = backend
+
+    def reset_attention_backend(self) -> None:
+        """
+        Resets the attention backend for the model. Following calls to `forward` will use the environment default or
+        the torch native scaled dot product attention.
+        """
+        from .attention_processor import Attention, MochiAttention
+
+        attention_classes = (Attention, MochiAttention)
+        for module in self.modules():
+            if not isinstance(module, attention_classes):
+                continue
+            processor = module.processor
+            if processor is None or not hasattr(processor, "_attention_backend"):
+                continue
+            processor._attention_backend = None
+
     def save_pretrained(
         self,
         save_directory: Union[str, os.PathLike],
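A usage sketch for the two new helpers, assuming a Flux transformer checkpoint and a "flash" backend name; the valid names are whatever `AttentionBackendName` exposes in the installed diffusers version, and torch SDPA remains the default:

import torch
from diffusers import FluxTransformer2DModel

# Model id and backend name below are assumptions for illustration.
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)

transformer.set_attention_backend("flash")   # route every attention processor through the chosen backend
# ... run inference ...
transformer.reset_attention_backend()        # back to the environment default / torch SDPA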

src/diffusers/models/transformers/transformer_flux.py

Lines changed: 119 additions & 41 deletions
@@ -24,10 +24,15 @@
 from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
-from ..attention import Attention, AttentionMixin, FeedForward
+from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
 from ..attention_dispatch import dispatch_attention_fn
 from ..cache_utils import CacheMixin
-from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
+from ..embeddings import (
+    CombinedTimestepGuidanceTextProjEmbeddings,
+    CombinedTimestepTextProjEmbeddings,
+    apply_rotary_emb,
+    get_1d_rotary_pos_embed,
+)
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
@@ -73,7 +78,6 @@ def get_qkv_projections(self, attn, hidden_states, encoder_hidden_states=None):
         """Public method to get projections based on whether we're using fused mode or not."""
         if attn.is_fused and hasattr(attn, "to_qkv"):
             return self._get_fused_projections(attn, hidden_states, encoder_hidden_states)
-
         return self._get_projections(attn, hidden_states, encoder_hidden_states)

     def __call__(
@@ -117,17 +121,10 @@ def __call__(
             value = torch.cat([encoder_value, value], dim=2)

         if image_rotary_emb is not None:
-            from ..embeddings import apply_rotary_emb
-
             query = apply_rotary_emb(query, image_rotary_emb)
             key = apply_rotary_emb(key, image_rotary_emb)

-        hidden_states = dispatch_attention_fn(
-            query,
-            key,
-            value,
-            attn_mask=attention_mask,
-        )
+        hidden_states = dispatch_attention_fn(query, key, value, attn_mask=attention_mask)

         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
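Both Flux processors now funnel their attention math through `dispatch_attention_fn` instead of calling `F.scaled_dot_product_attention` directly, which is what lets `set_attention_backend` swap kernels without touching processor code. A standalone sketch of the dispatch idea, not the diffusers implementation; only the native torch-SDPA path is wired up here:

from typing import Optional

import torch
import torch.nn.functional as F


def toy_dispatch_attention_fn(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    backend: Optional[str] = None,
) -> torch.Tensor:
    # A real dispatcher would map more backend names (flash, sage, xformers, ...)
    # to their kernels; this sketch only knows the native torch path.
    if backend in (None, "native"):
        return F.scaled_dot_product_attention(
            query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal
        )
    raise ValueError(f"Unknown attention backend: {backend}")


q = k = v = torch.randn(1, 8, 64, 64)  # (batch, heads, seq_len, head_dim)
out = toy_dispatch_attention_fn(q, k, v)
print(out.shape)  # torch.Size([1, 8, 64, 64])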
@@ -242,12 +239,10 @@ def __call__(
             value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)

         if image_rotary_emb is not None:
-            from .embeddings import apply_rotary_emb
-
             query = apply_rotary_emb(query, image_rotary_emb)
             key = apply_rotary_emb(key, image_rotary_emb)

-        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+        hidden_states = dispatch_attention_fn(query, key, value, dropout_p=0.0, is_causal=False)
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)

@@ -292,13 +287,76 @@ def __call__(


 @maybe_allow_in_graph
-class FluxAttention(Attention):
+class FluxAttention(torch.nn.Module, AttentionModuleMixin):
     _default_processor_cls = FluxAttnProcessor
     _available_processors = [
         FluxAttnProcessor,
         FluxIPAdapterAttnProcessor,
     ]

+    def __init__(
+        self,
+        query_dim: int,
+        heads: int = 8,
+        dim_head: int = 64,
+        dropout: float = 0.0,
+        bias: bool = False,
+        qk_norm: Optional[str] = None,
+        added_kv_proj_dim: Optional[int] = None,
+        added_proj_bias: Optional[bool] = True,
+        out_bias: bool = True,
+        eps: float = 1e-5,
+        out_dim: int = None,
+        context_pre_only: Optional[bool] = None,
+        pre_only: bool = False,
+        elementwise_affine: bool = True,
+        processor=None,
+    ):
+        super().__init__()
+        assert qk_norm == "rms_norm", "Flux uses RMSNorm"
+
+        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+        self.query_dim = query_dim
+        self.use_bias = bias
+        self.dropout = dropout
+        self.out_dim = out_dim if out_dim is not None else query_dim
+        self.context_pre_only = context_pre_only
+        self.pre_only = pre_only
+        self.heads = out_dim // dim_head if out_dim is not None else heads
+        self.added_proj_bias = added_proj_bias
+
+        self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+        self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+        self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
+        self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
+        self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
+
+        if not self.pre_only:
+            self.to_out = torch.nn.ModuleList([])
+            self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
+
+        if added_kv_proj_dim is not None:
+            self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
+            self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
+            self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+            self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+            self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+            self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias)
+
+        if processor is None:
+            processor = self._default_processor_cls()
+        self.set_processor(processor)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
+

 @maybe_allow_in_graph
 class FluxSingleTransformerBlock(nn.Module):
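`FluxAttention` now subclasses `torch.nn.Module` plus `AttentionModuleMixin` directly: the module owns the projections and norms, while `forward` hands all computation to a swappable processor. A minimal standalone sketch of that module/processor split (toy classes, not the diffusers ones):

import torch
import torch.nn.functional as F


class ToyProcessor:
    def __call__(self, attn: "ToyAttention", hidden_states: torch.Tensor) -> torch.Tensor:
        # All of the attention math lives in the processor.
        q, k, v = attn.to_q(hidden_states), attn.to_k(hidden_states), attn.to_v(hidden_states)
        b, s, _ = q.shape
        q, k, v = (t.view(b, s, attn.heads, -1).transpose(1, 2) for t in (q, k, v))
        out = F.scaled_dot_product_attention(q, k, v)
        return out.transpose(1, 2).reshape(b, s, -1)


class ToyAttention(torch.nn.Module):
    def __init__(self, dim: int, heads: int = 8, processor=None):
        super().__init__()
        self.heads = heads
        self.to_q = torch.nn.Linear(dim, dim)
        self.to_k = torch.nn.Linear(dim, dim)
        self.to_v = torch.nn.Linear(dim, dim)
        # Swapping the processor changes behaviour without touching the weights.
        self.processor = processor if processor is not None else ToyProcessor()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.processor(self, hidden_states)


attn = ToyAttention(dim=64, heads=8)
print(attn(torch.randn(2, 16, 64)).shape)  # torch.Size([2, 16, 64])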
@@ -330,20 +388,19 @@ def forward(
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
     ) -> torch.Tensor:
-        residual = hidden_states
+        joint_attention_kwargs = joint_attention_kwargs or {}
+
         norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
         mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
-        joint_attention_kwargs = joint_attention_kwargs or {}
         attn_output = self.attn(
             hidden_states=norm_hidden_states,
             image_rotary_emb=image_rotary_emb,
             **joint_attention_kwargs,
         )
+        attn_mlp_hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
+        proj_out = self.proj_out(attn_mlp_hidden_states)
+        hidden_states = hidden_states + gate.unsqueeze(1) * proj_out

-        hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
-        gate = gate.unsqueeze(1)
-        hidden_states = gate * self.proj_out(hidden_states)
-        hidden_states = residual + hidden_states
         if hidden_states.dtype == torch.float16:
             hidden_states = hidden_states.clip(-65504, 65504)

@@ -386,12 +443,13 @@ def forward(
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
+        joint_attention_kwargs = joint_attention_kwargs or {}

+        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
         norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
             encoder_hidden_states, emb=temb
         )
-        joint_attention_kwargs = joint_attention_kwargs or {}
+
         # Attention.
         attention_outputs = self.attn(
             hidden_states=norm_hidden_states,
@@ -410,7 +468,7 @@ def forward(
         hidden_states = hidden_states + attn_output

         norm_hidden_states = self.norm2(hidden_states)
-        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+        norm_hidden_states = norm_hidden_states * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)

         ff_output = self.ff(norm_hidden_states)
         ff_output = gate_mlp.unsqueeze(1) * ff_output
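The `[:, None]` indexing above is swapped for `unsqueeze(1)`; the two produce identical tensors, so this is purely a readability change. A quick standalone check with illustrative shapes:

import torch

scale_mlp = torch.randn(4, 3072)  # (batch, dim); shapes are illustrative
assert torch.equal(scale_mlp[:, None], scale_mlp.unsqueeze(1))
print(scale_mlp.unsqueeze(1).shape)  # torch.Size([4, 1, 3072])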
@@ -420,21 +478,54 @@ def forward(
             hidden_states = hidden_states + ip_attn_output

         # Process attention outputs for the `encoder_hidden_states`.
-
         context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
         encoder_hidden_states = encoder_hidden_states + context_attn_output

         norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
-        norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+        norm_encoder_hidden_states = norm_encoder_hidden_states * (
+            1 + c_scale_mlp.unsqueeze(1)
+        ) + c_shift_mlp.unsqueeze(1)

         context_ff_output = self.ff_context(norm_encoder_hidden_states)
         encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
+
         if encoder_hidden_states.dtype == torch.float16:
             encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)

         return encoder_hidden_states, hidden_states


+class FluxPosEmbed(nn.Module):
+    # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
+    def __init__(self, theta: int, axes_dim: List[int]):
+        super().__init__()
+        self.theta = theta
+        self.axes_dim = axes_dim
+
+    def forward(self, ids: torch.Tensor) -> torch.Tensor:
+        n_axes = ids.shape[-1]
+        cos_out = []
+        sin_out = []
+        pos = ids.float()
+        is_mps = ids.device.type == "mps"
+        is_npu = ids.device.type == "npu"
+        freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
+        for i in range(n_axes):
+            cos, sin = get_1d_rotary_pos_embed(
+                self.axes_dim[i],
+                pos[:, i],
+                theta=self.theta,
+                repeat_interleave_real=True,
+                use_real=True,
+                freqs_dtype=freqs_dtype,
+            )
+            cos_out.append(cos)
+            sin_out.append(sin)
+        freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
+        freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
+        return freqs_cos, freqs_sin
+
+
 class FluxTransformer2DModel(
     ModelMixin,
     ConfigMixin,
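The `FluxPosEmbed` added above builds one 1D rotary table per position axis and concatenates the cos/sin tables along the feature dimension. A simplified standalone sketch of that per-axis construction; it does not reproduce the exact `repeat_interleave` layout of `get_1d_rotary_pos_embed`, and the `axes_dim`/`ids` values are assumptions:

import torch


def toy_1d_rope(dim: int, pos: torch.Tensor, theta: float = 10000.0):
    # Half of `dim` frequencies, as in standard RoPE.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float64) / dim))
    angles = torch.outer(pos.to(torch.float64), freqs)  # (seq_len, dim // 2)
    return angles.cos(), angles.sin()


axes_dim = [16, 56, 56]                 # Flux-style split of the head dim (assumed)
ids = torch.zeros(4096, len(axes_dim))  # (seq_len, n_axes) position ids, illustrative
cos_out, sin_out = [], []
for i, dim in enumerate(axes_dim):
    cos, sin = toy_1d_rope(dim, ids[:, i])
    cos_out.append(cos)
    sin_out.append(sin)
freqs_cos = torch.cat(cos_out, dim=-1)  # (4096, 64): one half-table per axis
freqs_sin = torch.cat(sin_out, dim=-1)
print(freqs_cos.shape, freqs_sin.shape)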
@@ -537,10 +628,6 @@ def __init__(

         self.gradient_checkpointing = False

-        # Using inherited methods from AttentionMixin
-
-        # Using inherited methods from AttentionMixin
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -634,11 +721,7 @@ def forward(
         for index_block, block in enumerate(self.transformer_blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
                 encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
-                    block,
-                    hidden_states,
-                    encoder_hidden_states,
-                    temb,
-                    image_rotary_emb,
+                    block, hidden_states, encoder_hidden_states, temb, image_rotary_emb
                 )

             else:
@@ -665,12 +748,7 @@ def forward(

         for index_block, block in enumerate(self.single_transformer_blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                hidden_states = self._gradient_checkpointing_func(
-                    block,
-                    hidden_states,
-                    temb,
-                    image_rotary_emb,
-                )
+                hidden_states = self._gradient_checkpointing_func(block, hidden_states, temb, image_rotary_emb)

             else:
                 hidden_states = block(
