
Commit cc7b91d

update
1 parent 11ce6b8 commit cc7b91d

File tree

2 files changed: 1 addition & 3 deletions


src/diffusers/models/embeddings.py

Lines changed: 1 addition & 2 deletions
@@ -1594,13 +1594,12 @@ def pool_tokens(x: torch.Tensor, mask: torch.Tensor, *, keepdim=False) -> torch.Tensor:
         Returns:
             pooled: (B, D) tensor of pooled tokens.
         """
-        input_dtype = x.dtype
         assert x.size(1) == mask.size(1)  # Expected mask to have same length as tokens.
         assert x.size(0) == mask.size(0)  # Expected mask to have same batch size as tokens.
         mask = mask[:, :, None].to(dtype=x.dtype)
         mask = mask / mask.sum(dim=1, keepdim=True).clamp(min=1)
         pooled = (x * mask).sum(dim=1, keepdim=keepdim)
-        return pooled.to(input_dtype)
+        return pooled
 
     def forward(self, x: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
         r"""

src/diffusers/models/transformers/transformer_mochi.py

Lines changed: 0 additions & 1 deletion
@@ -136,7 +136,6 @@ def forward(
 
         emb = self.linear(self.silu(emb))
         scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
-
         hidden_states = self.norm(hidden_states.to(torch.float32)) * (1 + scale_msa[:, None].to(torch.float32))
         hidden_states = hidden_states.to(hidden_states_dtype)
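
For context, a minimal sketch of the modulation step this hunk sits in: the conditioning embedding is projected into four modulation tensors, and the hidden states are normalized and scaled in float32 before being cast back to their working dtype. The class name, layer sizes, and the plain LayerNorm below are assumptions for illustration, not the library's exact module.

import torch
import torch.nn as nn

class AdaNormZeroSketch(nn.Module):
    # Hypothetical minimal module; dim, conditioning_dim, and the norm choice are assumptions.
    def __init__(self, dim: int, conditioning_dim: int):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(conditioning_dim, 4 * dim)
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)

    def forward(self, hidden_states: torch.Tensor, emb: torch.Tensor):
        hidden_states_dtype = hidden_states.dtype
        # Project the conditioning embedding into four modulation tensors.
        emb = self.linear(self.silu(emb))
        scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
        # Normalize and scale in float32 for stability, then return to the working dtype.
        hidden_states = self.norm(hidden_states.to(torch.float32)) * (1 + scale_msa[:, None].to(torch.float32))
        hidden_states = hidden_states.to(hidden_states_dtype)
        return hidden_states, gate_msa, scale_mlp, gate_mlp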
