Commit 8b9d5b6

update
1 parent d99234f commit 8b9d5b6

File tree

2 files changed: +122 additions, -62 deletions

src/diffusers/models/attention_processor.py

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,6 @@
 from typing import Callable, List, Optional, Tuple, Union
 
 import torch
-from torch._higher_order_ops.flex_attention import sdpa_dense
 import torch.nn.functional as F
 from torch import nn
 
@@ -3597,6 +3596,7 @@ def apply_rotary_emb(x, freqs_cos, freqs_sin):
         value = torch.index_select(value, 2, select_index)
 
         from torch.nn.attention import SDPBackend, sdpa_kernel
+
         with sdpa_kernel([SDPBackend.EFFICIENT_ATTENTION]):
             hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
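For reference, here is a minimal standalone sketch of the sdpa_kernel context manager used in the hunk above. It assumes PyTorch 2.3+, where torch.nn.attention is available; the tensor shapes are illustrative only.

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Illustrative shapes: (batch, heads, seq_len, head_dim).
device = "cuda" if torch.cuda.is_available() else "cpu"
query = torch.randn(2, 8, 256, 64, device=device)
key = torch.randn(2, 8, 256, 64, device=device)
value = torch.randn(2, 8, 256, 64, device=device)

# sdpa_kernel restricts which backends scaled_dot_product_attention may pick inside
# the block; the commit pins the memory-efficient backend. Note that this backend is
# primarily a CUDA kernel, so on CPU the selection may raise "No available kernel".
with sdpa_kernel([SDPBackend.EFFICIENT_ATTENTION]):
    out = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)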

src/diffusers/models/transformers/transformer_mochi.py

Lines changed: 121 additions & 61 deletions
@@ -13,44 +13,71 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from operator import ipow
 from typing import Any, Dict, Optional, Tuple
 
 import torch
-from torch._prims_common import is_low_precision_dtype
 import torch.nn as nn
-from transformers.tokenization_utils_base import import_protobuf_decode_error
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils import is_torch_version, logging
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import FeedForward
-from ..attention_processor import Attention, MochiAttnProcessor2_0
+from ..attention_processor import MochiAttnProcessor2_0
 from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import (
     AdaLayerNormContinuous,
-    LuminaLayerNormContinuous,
 )
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
-class FP32ModulatedRMSNorm(nn.Module):
-    def __init__(self, dim, eps: float, elementwise_affine: bool = True):
+class MochiModulatedRMSNorm(nn.Module):
+    def __init__(self, eps: float):
         super().__init__()
 
         self.eps = eps
 
     def forward(self, hidden_states, scale=None):
+        hidden_states_dtype = hidden_states.dtype
+
         variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states.float() * torch.rsqrt(variance + self.eps)
+        hidden_states = hidden_states.to(torch.float32) * torch.rsqrt(variance + self.eps)
 
         if scale is not None:
             hidden_states = hidden_states * scale
 
+        hidden_states = hidden_states.to(hidden_states_dtype)
+
+        return hidden_states
+
+
+class MochiRMSNorm(nn.Module):
+    def __init__(self, dim, eps: float, elementwise_affine=True):
+        super().__init__()
+
+        self.eps = eps
+        if elementwise_affine:
+            self.weight = nn.Parameter(torch.ones(dim))
+        else:
+            self.weight = None
+
+    def forward(self, hidden_states):
+        hidden_states_dtype = hidden_states.dtype
+
+        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states.to(torch.float32) * torch.rsqrt(variance + self.eps)
+
+        if self.weight is not None:
+            # convert into half-precision if necessary
+            if self.weight.dtype in [torch.float16, torch.bfloat16]:
+                hidden_states = hidden_states.to(self.weight.dtype)
+            hidden_states = hidden_states * self.weight
+
+        hidden_states = hidden_states.to(hidden_states_dtype)
+
         return hidden_states
 
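As a quick illustration of the dtype handling in the two norms added above, here is a standalone restatement of MochiRMSNorm.forward as a plain function; the helper name mochi_rms_norm and the tensor shapes are illustrative only, not part of the commit.

import torch

def mochi_rms_norm(hidden_states, weight=None, eps=1e-5):
    # Mirrors MochiRMSNorm.forward in the hunk above: compute the RMS statistics in
    # float32, apply the optional learnable weight (matching its half precision when
    # needed), then return the result in the caller's dtype.
    input_dtype = hidden_states.dtype
    variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states.to(torch.float32) * torch.rsqrt(variance + eps)
    if weight is not None:
        if weight.dtype in (torch.float16, torch.bfloat16):
            hidden_states = hidden_states.to(weight.dtype)
        hidden_states = hidden_states * weight
    return hidden_states.to(input_dtype)

# Example: a (batch, heads, seq_len, head_dim) tensor normalized over head_dim,
# similar to how the per-head query/key norms are used later in this diff.
q = torch.randn(1, 24, 128, 128, dtype=torch.bfloat16)
w = torch.ones(128, dtype=torch.bfloat16)
print(mochi_rms_norm(q, w).dtype)  # torch.bfloat16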

@@ -59,49 +86,28 @@ def __init__(
         self,
         embedding_dim: int,
         conditioning_embedding_dim: int,
-        # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
-        # because the output is immediately scaled and shifted by the projected conditioning embeddings.
-        # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
-        # However, this is how it was implemented in the original code, and it's rather likely you should
-        # set `elementwise_affine` to False.
-        elementwise_affine=True,
         eps=1e-5,
         bias=True,
-        norm_type="layer_norm",
-        out_dim: Optional[int] = None,
     ):
         super().__init__()
 
         # AdaLN
         self.silu = nn.SiLU()
         self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias)
-
-        if norm_type == "layer_norm":
-            self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
-        elif norm_type == "rms_norm":
-            self.norm = FP32ModulatedRMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
-        else:
-            raise ValueError(f"unknown norm_type {norm_type}")
-
-        self.linear_2 = None
-        if out_dim is not None:
-            self.linear_2 = nn.Linear(embedding_dim, out_dim, bias=bias)
+        self.norm = MochiModulatedRMSNorm(eps=eps)
 
     def forward(
         self,
         x: torch.Tensor,
         conditioning_embedding: torch.Tensor,
     ) -> torch.Tensor:
-        output_dtype = x.dtype
-        # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT)
-        emb = self.linear_1(self.silu(conditioning_embedding).to(x.dtype))
-        scale = emb
-        x = self.norm(x, (1 + scale.unsqueeze(1).float()))
+        input_dtype = x.dtype
 
-        if self.linear_2 is not None:
-            x = self.linear_2(x)
+        # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT)
+        scale = self.linear_1(self.silu(conditioning_embedding).to(x.dtype))
+        x = self.norm(x, (1 + scale.unsqueeze(1).to(torch.float32)))
 
-        return x.to(output_dtype)
+        return x.to(input_dtype)
 
 
 class MochiRMSNormZero(nn.Module):
@@ -119,7 +125,7 @@ def __init__(
 
         self.silu = nn.SiLU()
         self.linear = nn.Linear(embedding_dim, hidden_dim)
-        self.norm = FP32ModulatedRMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
+        self.norm = MochiModulatedRMSNorm(eps=eps)
 
     def forward(
         self, hidden_states: torch.Tensor, emb: torch.Tensor
@@ -129,12 +135,76 @@ def forward(
         emb = self.linear(self.silu(emb))
         scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
 
-        hidden_states = self.norm(hidden_states, (1 + scale_msa[:, None].float()))
+        hidden_states = self.norm(hidden_states, (1 + scale_msa[:, None].to(torch.float32)))
         hidden_states = hidden_states.to(hidden_states_dtype)
 
         return hidden_states, gate_msa, scale_mlp, gate_mlp
 
 
+class MochiAttention(nn.Module):
+    def __init__(
+        self,
+        query_dim: int,
+        processor: Optional["MochiAttnProcessor2_0"],
+        heads: int = 8,
+        dim_head: int = 64,
+        dropout: float = 0.0,
+        bias: bool = False,
+        added_kv_proj_dim: Optional[int] = None,
+        added_proj_bias: Optional[bool] = True,
+        out_dim: int = None,
+        out_context_dim: int = None,
+        out_bias: bool = True,
+        context_pre_only: bool = False,
+        eps: float = 1e-5,
+    ):
+        super().__init__()
+        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+        self.out_dim = out_dim if out_dim is not None else query_dim
+        self.out_context_dim = out_context_dim if out_context_dim else query_dim
+        self.context_pre_only = context_pre_only
+
+        self.heads = out_dim // dim_head if out_dim is not None else heads
+
+        self.norm_q = MochiRMSNorm(dim_head, eps)
+        self.norm_k = MochiRMSNorm(dim_head, eps)
+        self.norm_added_q = MochiRMSNorm(dim_head, eps)
+        self.norm_added_k = MochiRMSNorm(dim_head, eps)
+
+        self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
+        self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias)
+        self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias)
+
+        self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+        self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+        if self.context_pre_only is not None:
+            self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+
+        self.to_out = nn.ModuleList([])
+        self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
+        self.to_out.append(nn.Dropout(dropout))
+
+        if not self.context_pre_only:
+            self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias)
+
+        self.processor = processor
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
+    ):
+        return self.processor(
+            self,
+            hidden_states,
+            encoder_hidden_states=encoder_hidden_states,
+            attention_mask=attention_mask,
+            **kwargs,
+        )
+
+
 @maybe_allow_in_graph
 class MochiTransformerBlock(nn.Module):
     r"""
@@ -183,34 +253,28 @@ def __init__(
             embedding_dim=pooled_projection_dim,
             conditioning_embedding_dim=dim,
             eps=eps,
-            elementwise_affine=False,
-            norm_type="rms_norm",
-            out_dim=None,
         )
 
-        self.attn1 = Attention(
+        self.attn1 = MochiAttention(
             query_dim=dim,
-            cross_attention_dim=None,
             heads=num_attention_heads,
             dim_head=attention_head_dim,
             bias=False,
-            qk_norm=qk_norm,
             added_kv_proj_dim=pooled_projection_dim,
             added_proj_bias=False,
             out_dim=dim,
             out_context_dim=pooled_projection_dim,
             context_pre_only=context_pre_only,
             processor=MochiAttnProcessor2_0(),
             eps=1e-5,
-            elementwise_affine=True,
         )
 
         # TODO(aryan): norm_context layers are not needed when `context_pre_only` is True
-        self.norm2 = FP32ModulatedRMSNorm(dim, eps=eps, elementwise_affine=False)
-        self.norm2_context = FP32ModulatedRMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False)
+        self.norm2 = MochiModulatedRMSNorm(eps=eps)
+        self.norm2_context = MochiModulatedRMSNorm(eps=eps) if not self.context_pre_only else None
 
-        self.norm3 = FP32ModulatedRMSNorm(dim, eps=eps, elementwise_affine=False)
-        self.norm3_context = FP32ModulatedRMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False)
+        self.norm3 = MochiModulatedRMSNorm(eps)
+        self.norm3_context = MochiModulatedRMSNorm(eps=eps) if not self.context_pre_only else None
 
         self.ff = FeedForward(dim, inner_dim=self.ff_inner_dim, activation_fn=activation_fn, bias=False)
         self.ff_context = None
@@ -222,8 +286,8 @@ def __init__(
                 bias=False,
             )
 
-        self.norm4 = FP32ModulatedRMSNorm(dim, eps=eps, elementwise_affine=False)
-        self.norm4_context = FP32ModulatedRMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False)
+        self.norm4 = MochiModulatedRMSNorm(eps=eps)
+        self.norm4_context = MochiModulatedRMSNorm(eps=eps)
 
     def forward(
         self,
@@ -249,26 +313,22 @@ def forward(
             attention_mask=joint_attention_mask,
         )
 
-        hidden_states = hidden_states + self.norm2(attn_hidden_states, torch.tanh(gate_msa).unsqueeze(1)).to(
-            hidden_states.dtype
-        )
-        norm_hidden_states = self.norm3(hidden_states, (1 + scale_mlp.unsqueeze(1).float())).to(hidden_states.dtype)
+        hidden_states = hidden_states + self.norm2(attn_hidden_states, torch.tanh(gate_msa).unsqueeze(1))
+        norm_hidden_states = self.norm3(hidden_states, (1 + scale_mlp.unsqueeze(1).to(torch.float32)))
         ff_output = self.ff(norm_hidden_states)
-        hidden_states = hidden_states + self.norm4(ff_output, torch.tanh(gate_mlp).unsqueeze(1)).to(
-            hidden_states.dtype
-        )
+        hidden_states = hidden_states + self.norm4(ff_output, torch.tanh(gate_mlp).unsqueeze(1))
 
         if not self.context_pre_only:
             encoder_hidden_states = encoder_hidden_states + self.norm2_context(
                 context_attn_hidden_states, torch.tanh(enc_gate_msa).unsqueeze(1)
-            ).to(encoder_hidden_states.dtype)
+            )
             norm_encoder_hidden_states = self.norm3_context(
-                encoder_hidden_states, (1 + enc_scale_mlp.unsqueeze(1).float())
-            ).to(encoder_hidden_states.dtype)
+                encoder_hidden_states, (1 + enc_scale_mlp.unsqueeze(1).to(torch.float32))
+            )
             context_ff_output = self.ff_context(norm_encoder_hidden_states)
             encoder_hidden_states = encoder_hidden_states + self.norm4_context(
                 context_ff_output, torch.tanh(enc_gate_mlp).unsqueeze(1)
-            ).to(encoder_hidden_states.dtype)
+            )
 
         return hidden_states, encoder_hidden_states
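The forward changes above work because the norms now restore the caller's dtype themselves, which is why the trailing .to(hidden_states.dtype) and .to(encoder_hidden_states.dtype) casts could be dropped. A minimal sketch of that residual-gating pattern follows; the modulated_rms_norm helper and the shapes are illustrative, not part of the commit.

import torch

def modulated_rms_norm(x, scale=None, eps=1e-5):
    # Same recipe as MochiModulatedRMSNorm.forward: normalize in float32, optionally
    # modulate by `scale`, then cast back to the input dtype.
    dtype = x.dtype
    variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
    out = x.to(torch.float32) * torch.rsqrt(variance + eps)
    if scale is not None:
        out = out * scale
    return out.to(dtype)

hidden_states = torch.randn(2, 16, 3072, dtype=torch.bfloat16)
attn_hidden_states = torch.randn(2, 16, 3072, dtype=torch.bfloat16)
gate_msa = torch.randn(2, 3072)

# Residual update in the style of MochiTransformerBlock.forward after this commit:
# no explicit cast back to hidden_states.dtype is needed because the norm already
# returns the input dtype.
hidden_states = hidden_states + modulated_rms_norm(attn_hidden_states, torch.tanh(gate_msa).unsqueeze(1))
print(hidden_states.dtype)  # torch.bfloat16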

0 commit comments
