Commit 80af1f0: update

1 parent 8c06092

3 files changed: 98 additions, 8 deletions

src/diffusers/models/attention_processor.py

Lines changed: 89 additions & 2 deletions
@@ -124,6 +124,7 @@ def __init__(
         context_pre_only=None,
         pre_only=False,
         elementwise_affine: bool = True,
+        is_causal: bool = False,
     ):
         super().__init__()

@@ -146,6 +147,7 @@ def __init__(
         self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim
         self.context_pre_only = context_pre_only
         self.pre_only = pre_only
+        self.is_causal = is_causal

         # we make use of this private variable to know whether this class is loaded
         # with an deprecated state dict so that we can convert it on the fly
@@ -195,8 +197,8 @@ def __init__(
             self.norm_q = RMSNorm(dim_head, eps=eps)
             self.norm_k = RMSNorm(dim_head, eps=eps)
         elif qk_norm == "l2":
-            self.norm_q = LpNorm(p=2, eps=eps)
-            self.norm_k = LpNorm(p=2, eps=eps)
+            self.norm_q = LpNorm(p=2, dim=-1)
+            self.norm_k = LpNorm(p=2, dim=-1)
         else:
             raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None,'layer_norm','fp32_layer_norm','rms_norm'")

@@ -2720,6 +2722,91 @@ def __call__(
         return hidden_states


+class MochiVaeAttnProcessor2_0:
+    r"""
+    Attention processor used in Mochi VAE.
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        residual = hidden_states
+        is_single_frame = hidden_states.shape[1] == 1
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if is_single_frame:
+            hidden_states = attn.to_v(hidden_states)
+
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+
+            if attn.residual_connection:
+                hidden_states = hidden_states + residual
+
+            hidden_states = hidden_states / attn.rescale_output_factor
+            return hidden_states
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=attn.is_causal
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
 class StableAudioAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
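
Taken together, the changes in this file mean an Attention block can now be built with is_causal=True and the new MochiVaeAttnProcessor2_0, which forwards the flag to F.scaled_dot_product_attention. A minimal usage sketch, assuming a diffusers build that includes this commit; query_dim=128 and the tensor shape are illustrative values, not taken from the diff:

import torch
from diffusers.models.attention_processor import Attention, MochiVaeAttnProcessor2_0

# Channels-last tokens over the frame axis: (batch, frames, channels).
hidden_states = torch.randn(1, 8, 128)

attn = Attention(
    query_dim=128,
    heads=128 // 32,
    dim_head=32,
    qk_norm="l2",                          # now constructs LpNorm(p=2, dim=-1)
    is_causal=True,                        # new flag, read by the processor below
    processor=MochiVaeAttnProcessor2_0(),  # new processor added in this commit
)

out = attn(hidden_states)  # causal self-attention over the 8 frames
print(out.shape)           # torch.Size([1, 8, 128])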

src/diffusers/models/autoencoders/autoencoder_kl_mochi.py

Lines changed: 7 additions & 3 deletions
@@ -23,7 +23,7 @@
 from ...utils import logging
 from ...utils.accelerate_utils import apply_forward_hook
 from ..activations import get_activation
-from ..attention_processor import Attention
+from ..attention_processor import Attention, MochiVaeAttnProcessor2_0
 from ..modeling_outputs import AutoencoderKLOutput
 from ..modeling_utils import ModelMixin
 from .autoencoder_kl_cogvideox import CogVideoXCausalConv3d
@@ -174,6 +174,8 @@ def __init__(
                         heads=out_channels // 32,
                         dim_head=32,
                         qk_norm="l2",
+                        is_causal=True,
+                        processor=MochiVaeAttnProcessor2_0(),
                     )
                 )
             else:
@@ -280,6 +282,8 @@ def __init__(
                         heads=in_channels // 32,
                         dim_head=32,
                         qk_norm="l2",
+                        is_causal=True,
+                        processor=MochiVaeAttnProcessor2_0(),
                     )
                 )
             else:
@@ -484,7 +488,7 @@ def __init__(

         self.nonlinearity = get_activation(act_fn)

-        self.fourier_features = FourierFeatures()
+        # self.fourier_features = FourierFeatures()
         self.proj_in = nn.Linear(in_channels, block_out_channels[0])
         self.block_in = MochiMidBlock3D(
             in_channels=block_out_channels[0], num_layers=layers_per_block[0], add_attention=add_attention_block[0]
@@ -517,7 +521,7 @@ def forward(
         new_conv_cache = {}
         conv_cache = conv_cache or {}

-        hidden_states = self.fourier_features(hidden_states)
+        # hidden_states = self.fourier_features(hidden_states)

         hidden_states = hidden_states.permute(0, 2, 3, 4, 1)
         hidden_states = self.proj_in(hidden_states)
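
With these changes the VAE's attention blocks run causal attention over the frame dimension. As a reminder of what the is_causal=True flag does once it reaches F.scaled_dot_product_attention, a self-contained sketch with illustrative shapes (not taken from the VAE config):

import torch
import torch.nn.functional as F

# (batch, heads, frames, head_dim)
q = torch.randn(1, 4, 8, 32)
k = torch.randn(1, 4, 8, 32)
v = torch.randn(1, 4, 8, 32)

# is_causal=True applies a lower-triangular mask, so frame t only attends to frames <= t.
causal_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# The same restriction written as an explicit boolean mask (True = keep).
mask = torch.ones(8, 8, dtype=torch.bool).tril()
masked_out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

print(torch.allclose(causal_out, masked_out, atol=1e-6))  # True (up to backend rounding)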

src/diffusers/models/normalization.py

Lines changed: 2 additions & 3 deletions
@@ -557,12 +557,11 @@ def forward(self, x):


 class LpNorm(nn.Module):
-    def __init__(self, p: int = 2, dim: int = -1, eps: float = 1e-12):
+    def __init__(self, p: int = 2, dim: int = -1):
         super().__init__()

         self.p = p
         self.dim = dim
-        self.eps = eps

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        return F.normalize(hidden_states, p=self.p, dim=self.dim, eps=self.eps)
+        return F.normalize(hidden_states, p=self.p, dim=self.dim)
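
With the explicit eps argument removed, F.normalize falls back to its built-in default (1e-12). A small sketch of what qk_norm="l2" now constructs and applies to query/key vectors, assuming a build with this commit; the tensor shape is illustrative:

import torch
from diffusers.models.normalization import LpNorm

norm = LpNorm(p=2, dim=-1)     # what qk_norm="l2" now builds for norm_q / norm_k
x = torch.randn(2, 4, 8, 32)   # e.g. (batch, heads, seq_len, head_dim)
y = norm(x)

# Every head_dim vector is rescaled to unit L2 norm; F.normalize's default eps
# still guards against division by zero.
print(y.norm(p=2, dim=-1))     # all entries ~1.0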
