
Commit 11ce6b8

update
1 parent 3c70b54 commit 11ce6b8

2 files changed: +35 -6 lines changed

src/diffusers/models/embeddings.py

Lines changed: 2 additions & 2 deletions

@@ -1597,9 +1597,9 @@ def pool_tokens(x: torch.Tensor, mask: torch.Tensor, *, keepdim=False) -> torch.
         input_dtype = x.dtype
         assert x.size(1) == mask.size(1)  # Expected mask to have same length as tokens.
         assert x.size(0) == mask.size(0)  # Expected mask to have same batch size as tokens.
-        mask = mask[:, :, None].to(dtype=torch.float32)
+        mask = mask[:, :, None].to(dtype=x.dtype)
         mask = mask / mask.sum(dim=1, keepdim=True).clamp(min=1)
-        pooled = (x.to(torch.float32) * mask).sum(dim=1, keepdim=keepdim)
+        pooled = (x * mask).sum(dim=1, keepdim=keepdim)
         return pooled.to(input_dtype)

     def forward(self, x: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
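This hunk keeps the masked mean pooling in the tensor's own dtype instead of upcasting the mask and the token activations to float32. A minimal standalone sketch of the updated behavior, illustrative only (the assertions are omitted and the example inputs are hypothetical):

import torch

def pool_tokens(x: torch.Tensor, mask: torch.Tensor, *, keepdim: bool = False) -> torch.Tensor:
    input_dtype = x.dtype
    # Broadcast the mask to (B, L, 1) in the same dtype as the tokens.
    mask = mask[:, :, None].to(dtype=x.dtype)
    # Normalize so valid positions form a weighted mean; clamp avoids division by zero.
    mask = mask / mask.sum(dim=1, keepdim=True).clamp(min=1)
    pooled = (x * mask).sum(dim=1, keepdim=keepdim)
    return pooled.to(input_dtype)

x = torch.randn(2, 5, 8, dtype=torch.bfloat16)
mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]], dtype=torch.bool)
pooled = pool_tokens(x, mask)  # shape (2, 8), bfloat16 end to end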

src/diffusers/models/transformers/transformer_mochi.py

Lines changed: 33 additions & 4 deletions

@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import numbers
 from typing import Any, Dict, Optional, Tuple

 import torch
@@ -54,6 +55,34 @@ def forward(self, hidden_states, scale=None):
         return hidden_states


+class MochiRMSNorm(nn.Module):
+    def __init__(self, dim, eps: float, elementwise_affine: bool = True):
+        super().__init__()
+
+        self.eps = eps
+
+        if isinstance(dim, numbers.Integral):
+            dim = (dim,)
+
+        self.dim = torch.Size(dim)
+
+        if elementwise_affine:
+            self.weight = nn.Parameter(torch.ones(dim))
+        else:
+            self.weight = None
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
+
+        if self.weight is not None:
+            hidden_states = hidden_states * self.weight
+        hidden_states = hidden_states.to(input_dtype)
+
+        return hidden_states
+
+
 class MochiLayerNormContinuous(nn.Module):
     def __init__(
         self,
@@ -139,10 +168,10 @@ def __init__(

         self.heads = out_dim // dim_head if out_dim is not None else heads

-        self.norm_q = RMSNorm(dim_head, eps, True)
-        self.norm_k = RMSNorm(dim_head, eps, True)
-        self.norm_added_q = RMSNorm(dim_head, eps, True)
-        self.norm_added_k = RMSNorm(dim_head, eps, True)
+        self.norm_q = MochiRMSNorm(dim_head, eps, True)
+        self.norm_k = MochiRMSNorm(dim_head, eps, True)
+        self.norm_added_q = MochiRMSNorm(dim_head, eps, True)
+        self.norm_added_k = MochiRMSNorm(dim_head, eps, True)

         self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
         self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias)
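As the added class shows, MochiRMSNorm computes both the normalization and the elementwise scale in float32 and casts back to the input dtype only once, at the end, presumably to avoid applying the learned weight in reduced precision. A hedged usage sketch for the QK-norm layers it now backs (the shapes and dtype are assumptions, not from the commit):

import torch

# Assumes the MochiRMSNorm definition added above is importable.
norm = MochiRMSNorm(dim=128, eps=1e-6, elementwise_affine=True)
q = torch.randn(2, 8, 256, 128, dtype=torch.bfloat16)  # (batch, heads, seq, dim_head)
out = norm(q)
assert out.dtype == torch.bfloat16  # normalized in float32 internally, returned in the input dtype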
