Commit bba83a4 (parent: 477937e)

fix SanaMultiscaleLinearAttention apply_quadratic_attention bf16

src/diffusers/models/attention_processor.py

Lines changed: 1 addition & 1 deletion
@@ -899,7 +899,7 @@ def apply_quadratic_attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
         scores = torch.matmul(key.transpose(-1, -2), query)
         scores = scores.to(dtype=torch.float32)
         scores = scores / (torch.sum(scores, dim=2, keepdim=True) + self.eps)
-        hidden_states = torch.matmul(value, scores)
+        hidden_states = torch.matmul(value, scores.to(value.dtype))
         return hidden_states

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
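Why the fix works: the quadratic-attention scores are promoted to float32 so the normalization (sum over the key axis plus `self.eps`) is numerically stable, but `value` stays in the model dtype. Under bf16 inference, `torch.matmul(value, scores)` then mixes bf16 and float32 operands and PyTorch raises a dtype-mismatch RuntimeError; casting the normalized scores back with `scores.to(value.dtype)` restores matching dtypes. Below is a minimal standalone sketch of the patched path, not the diffusers source; the tensor shapes and the `eps` default are illustrative assumptions.

```python
import torch

def apply_quadratic_attention(query, key, value, eps=1e-15):
    # Standalone sketch of the patched method; tensors are assumed to be
    # (batch, heads, dim, tokens). `eps` stands in for the module's self.eps.
    scores = torch.matmul(key.transpose(-1, -2), query)  # (B, H, N, N)
    scores = scores.to(dtype=torch.float32)              # normalize in fp32 for stability
    scores = scores / (torch.sum(scores, dim=2, keepdim=True) + eps)
    # The fix: cast scores back to value.dtype before the final matmul.
    # With bf16 inputs, multiplying bf16 `value` by fp32 `scores` would
    # raise a dtype-mismatch RuntimeError.
    hidden_states = torch.matmul(value, scores.to(value.dtype))
    return hidden_states

# Hypothetical usage: bf16 inputs, as when running a Sana model in bf16.
q = torch.randn(1, 8, 64, 32, dtype=torch.bfloat16)
k = torch.randn(1, 8, 64, 32, dtype=torch.bfloat16)
v = torch.randn(1, 8, 64, 32, dtype=torch.bfloat16)
out = apply_quadratic_attention(q, k, v)
assert out.dtype == torch.bfloat16
```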
