
Commit 2fd2ec4

fixes
1 parent 85c8734 commit 2fd2ec4

File tree

4 files changed: +427 additions, -164 deletions


src/diffusers/models/attention_processor.py

Lines changed: 84 additions & 2 deletions
```diff
@@ -1795,8 +1795,7 @@ def __call__(
             # dropout
             hidden_states = attn.to_out[1](hidden_states)

-            if hasattr(attn, "to_add_out"):
-                encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

             return hidden_states, encoder_hidden_states
         else:
@@ -3082,6 +3081,89 @@ def __call__(
         return hidden_states


+class MochiAttnProcessor2_0:
+    """Attention processor used in Mochi."""
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("MochiAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        breakpoint()
+        batch_size = hidden_states.size(0)
+
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        query = query.unflatten(2, (attn.heads, -1))
+        key = key.unflatten(2, (attn.heads, -1))
+        value = value.unflatten(2, (attn.heads, -1))
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        encoder_query = attn.add_q_proj(encoder_hidden_states)
+        encoder_key = attn.add_k_proj(encoder_hidden_states)
+        encoder_value = attn.add_v_proj(encoder_hidden_states)
+
+        encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
+        encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
+        encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
+
+        if attn.norm_added_q is not None:
+            encoder_query = attn.norm_added_q(encoder_query)
+        if attn.norm_added_k is not None:
+            encoder_key = attn.norm_added_k(encoder_key)
+
+        if image_rotary_emb is not None:
+            def apply_rotary_emb(x, freqs_cos, freqs_sin):
+                x_even = x[..., 0::2].float()
+                x_odd = x[..., 1::2].float()
+
+                cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype)
+                sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype)
+
+                return torch.stack([cos, sin], dim=-1).flatten(-2)
+
+            query = apply_rotary_emb(query, *image_rotary_emb)
+            key = apply_rotary_emb(key, *image_rotary_emb)
+
+        query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
+        encoder_query, encoder_key, encoder_value = encoder_query.transpose(1, 2), encoder_key.transpose(1, 2), encoder_value.transpose(1, 2)
+
+        sequence_length = query.size(2)
+        encoder_sequence_length = encoder_query.size(2)
+
+        query = torch.cat([query, encoder_query], dim=2)
+        key = torch.cat([key, encoder_key], dim=2)
+        value = torch.cat([value, encoder_value], dim=2)
+
+        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+        hidden_states = hidden_states.to(query.dtype)
+
+        hidden_states, encoder_hidden_states = hidden_states.split_with_sizes((sequence_length, encoder_sequence_length), dim=1)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+        return hidden_states, encoder_hidden_states
+
+
 class FusedAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses
```
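To make the rotary step above easier to follow, here is a minimal standalone sketch of the interleaved `apply_rotary_emb` helper defined inside `MochiAttnProcessor2_0.__call__`; the batch, sequence, and head sizes are illustrative, not Mochi's real configuration.

```python
import torch


def apply_rotary_emb(x, freqs_cos, freqs_sin):
    # Interleaved ("real-valued") RoPE: each (even, odd) feature pair is rotated
    # by a per-position, per-head angle, then the pairs are re-interleaved.
    x_even = x[..., 0::2].float()
    x_odd = x[..., 1::2].float()
    cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype)
    sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype)
    return torch.stack([cos, sin], dim=-1).flatten(-2)


# Illustrative shapes: (batch, seq_len, heads, head_dim), as used before the transpose above.
query = torch.randn(2, 8, 4, 64)
angles = torch.randn(8, 4, 32)  # one angle per (position, head, feature pair)
freqs_cos, freqs_sin = angles.cos(), angles.sin()

rotated = apply_rotary_emb(query, freqs_cos, freqs_sin)
print(rotated.shape)  # torch.Size([2, 8, 4, 64])
# A pure rotation preserves the norm of every feature pair, hence of each head vector.
print(torch.allclose(query.norm(dim=-1), rotated.norm(dim=-1), atol=1e-5))  # True
```

The `(seq_len, heads, head_dim // 2)` angle layout is exactly what the `MochiRoPE` module added in `transformer_mochi.py` below returns as `image_rotary_emb`.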

src/diffusers/models/normalization.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -246,13 +246,13 @@ class MochiRMSNormZero(nn.Module):
     """

     def __init__(
-        self, embedding_dim: int, hidden_dim: int, norm_eps: float = 1e-5, elementwise_affine: bool = False
+        self, embedding_dim: int, hidden_dim: int, eps: float = 1e-5, elementwise_affine: bool = False
     ) -> None:
         super().__init__()

         self.silu = nn.SiLU()
         self.linear = nn.Linear(embedding_dim, hidden_dim)
-        self.norm = RMSNorm(embedding_dim, eps=norm_eps, elementwise_affine=elementwise_affine)
+        self.norm = RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)

     def forward(
         self, hidden_states: torch.Tensor, emb: torch.Tensor
```
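The only change here is the keyword rename from `norm_eps` to `eps`, so callers can thread a single `eps` value through. A small usage sketch, assuming the module is importable from this branch; the sizes are illustrative, and the 4-tuple unpacking mirrors how `MochiTransformerBlock.forward` consumes `self.norm1(...)` in the next file.

```python
import torch

from diffusers.models.normalization import MochiRMSNormZero  # module path as in this diff

dim = 3072  # illustrative hidden size
norm1 = MochiRMSNormZero(dim, 4 * dim, eps=1e-6, elementwise_affine=False)  # `eps`, not `norm_eps`

hidden_states = torch.randn(2, 256, dim)  # (batch, tokens, dim), illustrative
temb = torch.randn(2, dim)                # conditioning embedding

# Same unpacking as MochiTransformerBlock.forward in transformer_mochi.py:
norm_hidden_states, gate_msa, scale_mlp, gate_mlp = norm1(hidden_states, temb)
print(norm_hidden_states.shape)  # torch.Size([2, 256, 3072])
```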

src/diffusers/models/transformers/transformer_mochi.py

Lines changed: 52 additions & 13 deletions
```diff
@@ -22,7 +22,7 @@
 from ...utils import logging
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import FeedForward
-from ..attention_processor import Attention, FluxAttnProcessor2_0
+from ..attention_processor import Attention, MochiAttnProcessor2_0
 from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
```
```diff
@@ -43,22 +43,23 @@ def __init__(
         qk_norm: str = "rms_norm",
         activation_fn: str = "swiglu",
         context_pre_only: bool = True,
+        eps: float = 1e-6,
     ) -> None:
         super().__init__()

         self.context_pre_only = context_pre_only
         self.ff_inner_dim = (4 * dim * 2) // 3
         self.ff_context_inner_dim = (4 * pooled_projection_dim * 2) // 3

-        self.norm1 = MochiRMSNormZero(dim, 4 * dim)
+        self.norm1 = MochiRMSNormZero(dim, 4 * dim, eps=eps, elementwise_affine=False)

         if not context_pre_only:
-            self.norm1_context = MochiRMSNormZero(dim, 4 * pooled_projection_dim)
+            self.norm1_context = MochiRMSNormZero(dim, 4 * pooled_projection_dim, eps=eps, elementwise_affine=False)
         else:
             self.norm1_context = LuminaLayerNormContinuous(
                 embedding_dim=pooled_projection_dim,
                 conditioning_embedding_dim=dim,
-                eps=1e-6,
+                eps=eps,
                 elementwise_affine=False,
                 norm_type="rms_norm",
                 out_dim=None,
@@ -76,16 +77,16 @@ def __init__(
             out_dim=dim,
             out_context_dim=pooled_projection_dim,
             context_pre_only=context_pre_only,
-            processor=FluxAttnProcessor2_0(),
-            eps=1e-6,
+            processor=MochiAttnProcessor2_0(),
+            eps=eps,
             elementwise_affine=True,
         )

-        self.norm2 = RMSNorm(dim, eps=1e-6, elementwise_affine=False)
-        self.norm2_context = RMSNorm(pooled_projection_dim, eps=1e-6, elementwise_affine=False)
+        self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=False)
+        self.norm2_context = RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False)

-        self.norm3 = RMSNorm(dim, eps=1e-6, elementwise_affine=False)
-        self.norm3_context = RMSNorm(pooled_projection_dim, eps=1e-56, elementwise_affine=False)
+        self.norm3 = RMSNorm(dim, eps=eps, elementwise_affine=False)
+        self.norm3_context = RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False)

         self.ff = FeedForward(dim, inner_dim=self.ff_inner_dim, activation_fn=activation_fn, bias=False)
         self.ff_context = None
@@ -94,8 +95,8 @@ def __init__(
                 pooled_projection_dim, inner_dim=self.ff_context_inner_dim, activation_fn=activation_fn, bias=False
             )

-        self.norm4 = RMSNorm(dim, eps=1e-6, elementwise_affine=False)
-        self.norm4_context = RMSNorm(pooled_projection_dim, eps=1e-56, elementwise_affine=False)
+        self.norm4 = RMSNorm(dim, eps=eps, elementwise_affine=False)
+        self.norm4_context = RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False)

     def forward(
         self,
```
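As context for the `ff_inner_dim` expressions above: `(4 * dim * 2) // 3` is the usual SwiGLU convention of shrinking the 4x MLP expansion by a factor of 2/3, so the gated MLP keeps roughly the parameter count of a plain 4x MLP. A worked example with illustrative sizes:

```python
dim = 3072                    # illustrative transformer width
pooled_projection_dim = 1536  # illustrative caption/context width

ff_inner_dim = (4 * dim * 2) // 3                            # 8192
ff_context_inner_dim = (4 * pooled_projection_dim * 2) // 3  # 4096
print(ff_inner_dim, ff_context_inner_dim)  # 8192 4096
```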
```diff
@@ -104,6 +105,7 @@ def forward(
         temb: torch.Tensor,
         image_rotary_emb: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
+        breakpoint()
         norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)

         if not self.context_pre_only:
@@ -140,6 +142,40 @@ def forward(
         return hidden_states, encoder_hidden_states


+class MochiRoPE(nn.Module):
+    def __init__(self, base_height: int = 192, base_width: int = 192, theta: float = 10000.0) -> None:
+        super().__init__()
+
+        self.target_area = base_height * base_width
+
+    def _centers(self, start, stop, num, device, dtype) -> torch.Tensor:
+        edges = torch.linspace(start, stop, num + 1, device=device, dtype=dtype)
+        return (edges[:-1] + edges[1:]) / 2
+
+    def _get_positions(self, num_frames: int, height: int, width: int, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
+        scale = (self.target_area / (height * width)) ** 0.5
+
+        t = torch.arange(num_frames, device=device, dtype=dtype)
+        h = self._centers(-height * scale / 2, height * scale / 2, height, device, dtype)
+        w = self._centers(-width * scale / 2, width * scale / 2, width, device, dtype)
+
+        grid_t, grid_h, grid_w = torch.meshgrid(t, h, w, indexing="ij")
+
+        positions = torch.stack([grid_t, grid_h, grid_w], dim=-1).view(-1, 3)
+        return positions
+
+    def _create_rope(self, freqs: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
+        freqs = torch.einsum("nd,dhf->nhf", pos, freqs)
+        freqs_cos = torch.cos(freqs)
+        freqs_sin = torch.sin(freqs)
+        return freqs_cos, freqs_sin
+
+    def forward(self, pos_frequencies: torch.Tensor, num_frames: int, height: int, width: int, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> Tuple[torch.Tensor, torch.Tensor]:
+        pos = self._get_positions(num_frames, height, width, device, dtype)
+        rope_cos, rope_sin = self._create_rope(pos_frequencies, pos)
+        return rope_cos, rope_sin
+
+
 @maybe_allow_in_graph
 class MochiTransformer3DModel(ModelMixin, ConfigMixin):
     _supports_gradient_checkpointing = True
```
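A hedged standalone sketch of what `MochiRoPE` computes, mirroring the class above with illustrative grid and head sizes: centered, area-normalized (t, h, w) positions are projected onto the learned per-axis frequencies and turned into cos/sin tables.

```python
import torch

num_frames, height, width = 4, 6, 8  # post-patch grid, illustrative
num_heads, head_dim = 3, 16          # illustrative; Mochi uses larger values

# Learned per-axis frequencies, shaped like MochiTransformer3DModel.pos_frequencies.
pos_frequencies = torch.randn(3, num_heads, head_dim // 2)

# Positions: one (t, h, w) triple per token, with h/w rescaled so the grid covers
# roughly the same area as the base_height x base_width reference (192 x 192).
target_area = 192 * 192
scale = (target_area / (height * width)) ** 0.5
edges_h = torch.linspace(-height * scale / 2, height * scale / 2, height + 1)
edges_w = torch.linspace(-width * scale / 2, width * scale / 2, width + 1)
h = (edges_h[:-1] + edges_h[1:]) / 2  # bin centers, as in MochiRoPE._centers
w = (edges_w[:-1] + edges_w[1:]) / 2
t = torch.arange(num_frames, dtype=torch.float32)

grid_t, grid_h, grid_w = torch.meshgrid(t, h, w, indexing="ij")
positions = torch.stack([grid_t, grid_h, grid_w], dim=-1).view(-1, 3)  # (T*H*W, 3)

# Project each position onto the per-head, per-frequency axes and take cos/sin.
freqs = torch.einsum("nd,dhf->nhf", positions, pos_frequencies)  # (T*H*W, heads, head_dim // 2)
freqs_cos, freqs_sin = freqs.cos(), freqs.sin()

print(positions.shape, freqs_cos.shape)  # torch.Size([192, 3]) torch.Size([192, 3, 8])
```

These `(tokens, heads, head_dim // 2)` tables are the `image_rotary_emb` pair consumed by `apply_rotary_emb` in `MochiAttnProcessor2_0`.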
```diff
@@ -169,6 +205,7 @@ def __init__(
             patch_size=patch_size,
             in_channels=in_channels,
             embed_dim=inner_dim,
+            pos_embed_type=None,
         )

         self.time_embed = MochiCombinedTimestepCaptionEmbedding(
@@ -180,6 +217,7 @@ def __init__(
         )

         self.pos_frequencies = nn.Parameter(torch.empty(3, num_attention_heads, attention_head_dim // 2))
+        self.rope = MochiRoPE()

         self.transformer_blocks = nn.ModuleList(
             [
@@ -207,7 +245,6 @@ def forward(
         encoder_hidden_states: torch.Tensor,
         timestep: torch.LongTensor,
         encoder_attention_mask: torch.Tensor,
-        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         return_dict: bool = True,
     ) -> torch.Tensor:
         batch_size, num_channels, num_frames, height, width = hidden_states.shape
@@ -224,6 +261,8 @@ def forward(
         hidden_states = self.patch_embed(hidden_states)
         hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2)

+        image_rotary_emb = self.rope(self.pos_frequencies, num_frames, post_patch_height, post_patch_width, device=hidden_states.device, dtype=torch.float32)
+
         for i, block in enumerate(self.transformer_blocks):
             hidden_states, encoder_hidden_states = block(
                 hidden_states=hidden_states,
```
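Since `image_rotary_emb` is now built inside `forward`, the RoPE grid has to match the patchified token count. A small bookkeeping sketch follows; the `// patch_size` derivation of `post_patch_height` and `post_patch_width` is an assumption, since that code sits outside the hunks shown here, and the latent shape is illustrative.

```python
# Illustrative latent shape (B, C, T, H, W) and patch size, not the real Mochi config.
batch_size, num_channels, num_frames, height, width = 1, 12, 7, 30, 40
patch_size = 2

post_patch_height = height // patch_size  # 15 (assumed derivation)
post_patch_width = width // patch_size    # 20 (assumed derivation)

# After patch_embed + unflatten/flatten, the visual stream holds one token per
# (frame, patch_row, patch_col), which is exactly the grid MochiRoPE is asked to cover.
visual_tokens = num_frames * post_patch_height * post_patch_width
print(visual_tokens)  # 2100
```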
