
Commit 0e9e281

Commit message: fix
1 parent 05ebd6c commit 0e9e281

4 files changed (+258, -36 lines)

src/diffusers/models/attention_processor.py

Lines changed: 212 additions & 0 deletions
@@ -717,6 +717,218 @@ def fuse_projections(self, fuse=True):
         self.fused_projections = fuse
 
 
+class AsymmetricAttention(nn.Module):
+    def __init__(
+        self,
+        query_dim: int,
+        query_context_dim: int,
+        num_attention_heads: int = 8,
+        attention_head_dim: int = 64,
+        bias: bool = False,
+        context_bias: bool = False,
+        out_dim: Optional[int] = None,
+        out_context_dim: Optional[int] = None,
+        qk_norm: Optional[str] = None,
+        eps: float = 1e-5,
+        elementwise_affine: bool = True,
+        processor: Optional["AttnProcessor"] = None,
+    ) -> None:
+        super().__init__()
+
+        from .normalization import RMSNorm
+
+        self.query_dim = query_dim
+        self.query_context_dim = query_context_dim
+        self.inner_dim = out_dim if out_dim is not None else num_attention_heads * attention_head_dim
+        self.out_dim = out_dim if out_dim is not None else query_dim
+
+        self.scale = attention_head_dim ** -0.5
+        self.num_attention_heads = out_dim // attention_head_dim if out_dim is not None else num_attention_heads
+
+        if qk_norm is None:
+            self.norm_q = None
+            self.norm_k = None
+            self.norm_context_q = None
+            self.norm_context_k = None
+        elif qk_norm == "rms_norm":
+            self.norm_q = RMSNorm(attention_head_dim, eps=eps, elementwise_affine=elementwise_affine)
+            self.norm_k = RMSNorm(attention_head_dim, eps=eps, elementwise_affine=elementwise_affine)
+            self.norm_context_q = RMSNorm(attention_head_dim, eps=eps, elementwise_affine=elementwise_affine)
+            self.norm_context_k = RMSNorm(attention_head_dim, eps=eps, elementwise_affine=elementwise_affine)
+        else:
+            raise ValueError(f"Unknown qk_norm: {qk_norm}. Should be None or `rms_norm`.")
+
+        self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
+        self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias)
+        self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias)
+
+        self.to_context_q = nn.Linear(query_context_dim, self.inner_dim, bias=context_bias)
+        self.to_context_k = nn.Linear(query_context_dim, self.inner_dim, bias=context_bias)
+        self.to_context_v = nn.Linear(query_context_dim, self.inner_dim, bias=context_bias)
+
+        # TODO(aryan): Take care of dropouts for training purpose in future
+        self.to_out = nn.ModuleList([nn.Linear(self.inner_dim, self.out_dim)])
+
+        # The context output projection only exists when the text stream is carried forward.
+        self.to_context_out = None
+        if out_context_dim is not None:
+            self.to_context_out = nn.ModuleList([nn.Linear(self.inner_dim, out_context_dim)])
+
+        if processor is None:
+            processor = AsymmetricAttnProcessor2_0()
+
+        self.set_processor(processor)
+
+    def set_processor(self, processor: "AttnProcessor") -> None:
+        r"""
+        Set the attention processor to use.
+
+        Args:
+            processor (`AttnProcessor`):
+                The attention processor to use.
+        """
+        # if current processor is in `self._modules` and if passed `processor` is not, we need to
+        # pop `processor` from `self._modules`
+        if (
+            hasattr(self, "processor")
+            and isinstance(self.processor, torch.nn.Module)
+            and not isinstance(processor, torch.nn.Module)
+        ):
+            logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
+            self._modules.pop("processor")
+
+        self.processor = processor
+
+    def get_processor(self) -> "AttentionProcessor":
+        r"""
+        Get the attention processor in use.
+
+        Returns:
+            "AttentionProcessor": The attention processor in use.
+        """
+        return self.processor
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        **cross_attention_kwargs,
+    ) -> torch.Tensor:
+        r"""
+        The forward method of the `AsymmetricAttention` class.
+
+        Args:
+            hidden_states (`torch.Tensor`):
+                The hidden states of the query.
+            encoder_hidden_states (`torch.Tensor`, *optional*):
+                The hidden states of the encoder.
+            attention_mask (`torch.Tensor`, *optional*):
+                The attention mask to use. If `None`, no mask is applied.
+            **cross_attention_kwargs:
+                Additional keyword arguments to pass along to the cross attention.
+
+        Returns:
+            `torch.Tensor`: The output of the attention layer.
+        """
+        # The `AsymmetricAttention` class can call different attention processors / attention functions
+        # here we simply pass along all tensors to the selected processor class
+        # For standard processors that are defined here, `**cross_attention_kwargs` is empty
+
+        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
+        quiet_attn_parameters = {"ip_adapter_masks"}
+        unused_kwargs = [
+            k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters
+        ]
+        if len(unused_kwargs) > 0:
+            logger.warning(
+                f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
+            )
+        cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}
+
+        return self.processor(
+            self,
+            hidden_states,
+            encoder_hidden_states=encoder_hidden_states,
+            attention_mask=attention_mask,
+            **cross_attention_kwargs,
+        )
+
+
+class AsymmetricAttnProcessor2_0:
+    r"""
+    Processor for implementing Asymmetric SDPA as described in Genmo/Mochi (TODO(aryan) add link).
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("AsymmetricAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn: AsymmetricAttention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    ) -> torch.Tensor:
+        batch_size = hidden_states.size(0)
+
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        query_context = attn.to_context_q(encoder_hidden_states)
+        key_context = attn.to_context_k(encoder_hidden_states)
+        value_context = attn.to_context_v(encoder_hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.num_attention_heads
+
+        # [batch, seq, inner_dim] -> [batch, heads, seq, head_dim]
+        query = query.unflatten(2, (attn.num_attention_heads, head_dim)).transpose(1, 2)
+        key = key.unflatten(2, (attn.num_attention_heads, head_dim)).transpose(1, 2)
+        value = value.unflatten(2, (attn.num_attention_heads, head_dim)).transpose(1, 2)
+
+        query_context = query_context.unflatten(2, (attn.num_attention_heads, head_dim)).transpose(1, 2)
+        key_context = key_context.unflatten(2, (attn.num_attention_heads, head_dim)).transpose(1, 2)
+        value_context = value_context.unflatten(2, (attn.num_attention_heads, head_dim)).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        if attn.norm_context_q is not None:
+            query_context = attn.norm_context_q(query_context)
+        if attn.norm_context_k is not None:
+            key_context = attn.norm_context_k(key_context)
+
+        if image_rotary_emb is not None:
+            from .embeddings import apply_rotary_emb
+
+            query = apply_rotary_emb(query, image_rotary_emb)
+            key = apply_rotary_emb(key, image_rotary_emb)
+
+        sequence_length = query.size(2)
+        context_sequence_length = query_context.size(2)
+
+        # Joint attention: concatenate the visual and text tokens along the sequence dimension.
+        query = torch.cat([query, query_context], dim=2)
+        key = torch.cat([key, key_context], dim=2)
+        value = torch.cat([value, value_context], dim=2)
+
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
+        )
+        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+        hidden_states = hidden_states.to(query.dtype)
+        hidden_states, encoder_hidden_states = hidden_states.split_with_sizes(
+            [sequence_length, context_sequence_length], dim=1
+        )
+
+        hidden_states = attn.to_out[0](hidden_states)
+        if attn.to_context_out is not None:
+            encoder_hidden_states = attn.to_context_out[0](encoder_hidden_states)
+
+        return hidden_states, encoder_hidden_states
+
+
 class AttnProcessor:
     r"""
     Default processor for performing attention-related computations.
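For reviewers, here is a minimal smoke test of the AsymmetricAttention module added above. The sizes are illustrative placeholders rather than the Mochi checkpoint dimensions, and the snippet assumes this branch of diffusers is installed:

import torch
from diffusers.models.attention_processor import AsymmetricAttention

# Small, made-up dimensions: 8 heads of 64 channels for the video stream,
# a 256-wide text stream, both projected into a shared 512-wide inner space.
attn = AsymmetricAttention(
    query_dim=512,
    query_context_dim=256,
    num_attention_heads=8,
    attention_head_dim=64,
    out_dim=512,
    out_context_dim=256,
    qk_norm="rms_norm",
    eps=1e-6,
    elementwise_affine=False,
)

video_tokens = torch.randn(2, 128, 512)   # (batch, video sequence length, query_dim)
caption_tokens = torch.randn(2, 32, 256)  # (batch, text sequence length, query_context_dim)

# The processor concatenates both token streams, runs one joint SDPA pass,
# then splits the result and projects each stream back to its own width.
out, context_out = attn(video_tokens, encoder_hidden_states=caption_tokens)
print(out.shape, context_out.shape)  # torch.Size([2, 128, 512]) torch.Size([2, 32, 256])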

src/diffusers/models/embeddings.py

Lines changed: 3 additions & 3 deletions
@@ -1304,16 +1304,16 @@ def forward(self, timestep, caption_feat, caption_mask):
 
 class MochiCombinedTimestepCaptionEmbedding(nn.Module):
     def __init__(
-        self, embedding_dim: int, pooled_projection_dim: int, time_embed_dim: int = 256, num_attention_heads: int = 8
+        self, embedding_dim: int, pooled_projection_dim: int, text_embed_dim: int, time_embed_dim: int = 256, num_attention_heads: int = 8
     ) -> None:
         super().__init__()
 
         self.time_proj = Timesteps(num_channels=time_embed_dim, flip_sin_to_cos=True, downscale_freq_shift=0.0)
         self.timestep_embedder = TimestepEmbedding(in_channels=time_embed_dim, time_embed_dim=embedding_dim)
         self.pooler = MochiAttentionPool(
-            num_attention_heads=num_attention_heads, embed_dim=pooled_projection_dim, output_dim=embedding_dim
+            num_attention_heads=num_attention_heads, embed_dim=text_embed_dim, output_dim=embedding_dim
         )
-        self.caption_proj = nn.Linear(embedding_dim, pooled_projection_dim)
+        self.caption_proj = nn.Linear(text_embed_dim, pooled_projection_dim)
 
     def forward(
         self,
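For context, a rough sketch of how the updated constructor is wired. Values are illustrative, the call mirrors the call site in transformer_mochi.py below, and the exact output shapes depend on MochiAttentionPool, which is not shown in this diff:

import torch
from diffusers.models.embeddings import MochiCombinedTimestepCaptionEmbedding

# With this change, text_embed_dim (the raw T5 feature width) feeds the attention pooler and
# the caption projection, while embedding_dim only sets the width of the pooled/timestep output.
embed = MochiCombinedTimestepCaptionEmbedding(
    embedding_dim=3072,
    pooled_projection_dim=1536,
    text_embed_dim=4096,
    time_embed_dim=256,
    num_attention_heads=8,
)

timestep = torch.tensor([1.0, 250.0])
caption_feat = torch.randn(2, 77, 4096)             # (batch, text tokens, text_embed_dim)
caption_mask = torch.ones(2, 77, dtype=torch.bool)

# Mirrors `temb, encoder_hidden_states = self.time_embed(...)` in MochiTransformer3DModel.forward:
# temb is the pooled timestep+caption conditioning, caption_proj the per-token caption features
# projected from text_embed_dim (4096) down to pooled_projection_dim (1536).
temb, caption_proj = embed(timestep, caption_feat, caption_mask)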

src/diffusers/models/transformers/transformer_mochi.py

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@
2121
from ...configuration_utils import ConfigMixin, register_to_config
2222
from ...utils import logging
2323
from ...utils.torch_utils import maybe_allow_in_graph
24-
from ..attention import Attention, FeedForward, JointAttnProcessor2_0
24+
from ..attention import FeedForward
25+
from ..attention_processor import AsymmetricAttention, AsymmetricAttnProcessor2_0
2526
from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed
2627
from ..modeling_outputs import Transformer2DModelOutput
2728
from ..modeling_utils import ModelMixin
@@ -46,23 +47,27 @@ def __init__(
4647
super().__init__()
4748

4849
self.context_pre_only = context_pre_only
50+
self.ff_inner_dim = (4 * dim * 2) // 3
51+
self.ff_context_inner_dim = (4 * pooled_projection_dim * 2) // 3
4952

5053
self.norm1 = MochiRMSNormZero(dim, 4 * dim)
5154

52-
if context_pre_only:
55+
if not context_pre_only:
5356
self.norm1_context = MochiRMSNormZero(dim, 4 * pooled_projection_dim)
5457
else:
55-
self.norm1_context = RMSNorm(pooled_projection_dim, eps=1e-6, elementwise_affine=False)
58+
self.norm1_context = nn.Linear(dim, pooled_projection_dim)
5659

57-
self.attn = Attention(
60+
self.attn = AsymmetricAttention(
5861
query_dim=dim,
59-
heads=num_attention_heads,
62+
query_context_dim=pooled_projection_dim,
63+
num_attention_heads=num_attention_heads,
6064
attention_head_dim=attention_head_dim,
61-
out_dim=4 * dim,
65+
out_dim=dim,
66+
out_context_dim=None if context_pre_only else pooled_projection_dim,
6267
qk_norm=qk_norm,
6368
eps=1e-6,
6469
elementwise_affine=False,
65-
processor=JointAttnProcessor2_0(),
70+
processor=AsymmetricAttnProcessor2_0(),
6671
)
6772

6873
self.norm2 = RMSNorm(dim, eps=1e-6, elementwise_affine=False)
@@ -71,8 +76,10 @@ def __init__(
7176
self.norm3 = RMSNorm(dim, eps=1e-6, elementwise_affine=False)
7277
self.norm3_context = RMSNorm(pooled_projection_dim, eps=1e-56, elementwise_affine=False)
7378

74-
self.ff = FeedForward(dim, mult=4, activation_fn=activation_fn)
75-
self.ff_context = FeedForward(pooled_projection_dim, mult=4, activation_fn=activation_fn)
79+
self.ff = FeedForward(dim, inner_dim=self.ff_inner_dim, activation_fn=activation_fn, bias=False)
80+
self.ff_context = None
81+
if not context_pre_only:
82+
self.ff_context = FeedForward(pooled_projection_dim, inner_dim=self.ff_context_inner_dim, activation_fn=activation_fn, bias=False)
7683

7784
self.norm4 = RMSNorm(dim, eps=1e-6, elementwise_affine=False)
7885
self.norm4_context = RMSNorm(pooled_projection_dim, eps=1e-56, elementwise_affine=False)
@@ -110,10 +117,10 @@ def forward(
110117
)
111118

112119
ff_output = self.ff(hidden_states)
113-
context_ff_output = self.ff_context(encoder_hidden_states)
114-
115120
hidden_states = hidden_states + ff_output * torch.tanh(gate_mlp).unsqueeze(1)
121+
116122
if not self.context_pre_only:
123+
context_ff_output = self.ff_context(encoder_hidden_states)
117124
encoder_hidden_states = encoder_hidden_states + context_ff_output * torch.tanh(enc_gate_mlp).unsqueeze(0)
118125

119126
return hidden_states, encoder_hidden_states
@@ -131,11 +138,9 @@ def __init__(
131138
attention_head_dim: int = 128,
132139
num_layers: int = 48,
133140
pooled_projection_dim: int = 1536,
134-
in_channels=12,
141+
in_channels: int = 12,
135142
out_channels: Optional[int] = None,
136143
qk_norm: str = "rms_norm",
137-
timestep_mlp_bias=True,
138-
timestep_scale=1000.0,
139144
text_embed_dim: int = 4096,
140145
time_embed_dim: int = 256,
141146
activation_fn: str = "swiglu",
@@ -146,19 +151,20 @@ def __init__(
146151
inner_dim = num_attention_heads * attention_head_dim
147152
out_channels = out_channels or in_channels
148153

149-
self.time_embed = MochiCombinedTimestepCaptionEmbedding(
150-
embedding_dim=text_embed_dim,
151-
pooled_projection_dim=pooled_projection_dim,
152-
time_embed_dim=time_embed_dim,
153-
num_attention_heads=8,
154-
)
155-
156154
self.patch_embed = PatchEmbed(
157155
patch_size=patch_size,
158156
in_channels=in_channels,
159157
embed_dim=inner_dim,
160158
)
161159

160+
self.time_embed = MochiCombinedTimestepCaptionEmbedding(
161+
embedding_dim=inner_dim,
162+
pooled_projection_dim=pooled_projection_dim,
163+
text_embed_dim=text_embed_dim,
164+
time_embed_dim=time_embed_dim,
165+
num_attention_heads=8,
166+
)
167+
162168
self.pos_frequencies = nn.Parameter(torch.empty(3, num_attention_heads, attention_head_dim // 2))
163169

164170
self.transformer_blocks = nn.ModuleList(
@@ -170,7 +176,7 @@ def __init__(
170176
pooled_projection_dim=pooled_projection_dim,
171177
qk_norm=qk_norm,
172178
activation_fn=activation_fn,
173-
context_pre_only=i < num_layers - 1,
179+
context_pre_only=i == num_layers - 1,
174180
)
175181
for i in range(num_layers)
176182
]
@@ -196,7 +202,7 @@ def forward(
196202
post_patch_height = height // p
197203
post_patch_width = width // p
198204

199-
temb, caption_proj = self.time_embed(timestep, encoder_hidden_states, encoder_attention_mask)
205+
temb, encoder_hidden_states = self.time_embed(timestep, encoder_hidden_states, encoder_attention_mask)
200206

201207
hidden_states = self.patch_embed(hidden_states)
202208

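As a quick sanity check on the block changes above, the new feed-forward widths and the corrected context_pre_only wiring work out as follows. The head count of 24 is an assumption; only attention_head_dim=128 and the other defaults are visible in this diff:

dim = 24 * 128                       # 3072, assuming 24 heads of attention_head_dim=128
pooled_projection_dim = 1536
num_layers = 48

ff_inner_dim = (4 * dim * 2) // 3                            # 8192: SwiGLU-style 2/3 scaling of the 4x expansion
ff_context_inner_dim = (4 * pooled_projection_dim * 2) // 3  # 4096

# With `context_pre_only=i == num_layers - 1`, only the final block drops the text stream:
# it swaps the MochiRMSNormZero context norm for a plain Linear and has no ff_context
# or context output projection.
context_pre_only_flags = [i == num_layers - 1 for i in range(num_layers)]
assert sum(context_pre_only_flags) == 1 and context_pre_only_flags[-1]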