Commit 0574ba5

fix(transformers): align SDPA with Transformers' logic (#1167)
* fix sdpa to align with Transformers' logic
* speed up execution with `ops.speed_fusion_attention`
* add backprop test
* small fix
1 parent 090fcbb commit 0574ba5
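
For orientation, a minimal usage sketch (not part of this commit; it reuses the mock module and BNSD-shaped inputs from the new tests and assumes a MindSpore build where `ops.speed_fusion_attention` is available). With the aligned signature, the mask can be omitted and causality is inferred from the module:

import numpy as np
from mindspore import tensor

from mindone.transformers.integrations.sdpa_attention import sdpa_attention_forward


class MockAttentionModule:  # minimal stand-in for an attention layer, as in the new tests
    def __init__(self, is_causal):
        self.is_causal = is_causal
        self.training = False


# B, N, S, D = 2, 8, 256, 32, matching the shapes used in the tests added below
q, k, v = (tensor(np.random.uniform(size=(2, 8, 256, 32)).astype(np.float32)) for _ in range(3))

# attention_mask now defaults to None; with is_causal left unset, causality is taken
# from `module.is_causal`, so a lower-triangular mask is built internally
out, _ = sdpa_attention_forward(MockAttentionModule(is_causal=True), q, k, v)
print(out.shape)  # (2, 256, 8, 32): head and sequence axes are swapped back before returning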

File tree

3 files changed, +211 -15 lines changed

Lines changed: 49 additions & 15 deletions
@@ -1,9 +1,13 @@
 """Adapted from https://github.com/huggingface/transformers/tree/main/src/transformers/integrations/sdpa_attention.py."""
-
+from math import sqrt
 from typing import Optional
 
 import mindspore as ms
-from mindspore import mint, nn
+from mindspore import mint, nn, ops
+
+from ..utils import logging
+
+logger = logging.get_logger(__name__)
 
 
 def repeat_kv(hidden_states: ms.Tensor, n_rep: int) -> ms.Tensor:
@@ -23,23 +27,53 @@ def sdpa_attention_forward(
     query: ms.Tensor,
     key: ms.Tensor,
     value: ms.Tensor,
-    attention_mask: Optional[ms.Tensor],
+    attention_mask: Optional[ms.Tensor] = None,
     dropout: float = 0.0,
     scaling: Optional[float] = None,
-    is_causal: Optional[bool] = None,  # to align with torch
+    is_causal: Optional[bool] = None,
     **kwargs,
 ) -> tuple[ms.Tensor, None]:
-    key_states = repeat_kv(key, module.num_key_value_groups)
-    value_states = repeat_kv(value, module.num_key_value_groups)
+    if kwargs.get("output_attentions", False) or kwargs.get("head_mask", None) is not None:
+        logger.warning_once(
+            "`sdpa` attention does not support `output_attentions=True` or `head_mask`."
+            " Please set your attention to `eager` if you want any of these features."
+        )
+
+    if hasattr(module, "num_key_value_groups"):
+        key = repeat_kv(key, module.num_key_value_groups)
+        value = repeat_kv(value, module.num_key_value_groups)
+
+    if attention_mask is not None and attention_mask.ndim == 4:
+        attention_mask = attention_mask[:, :, :, : key.shape[-2]]
+
+    # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+    # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+    # Note that it is important to check first for the shape, otherwise compile will fail with `argument 'is_causal' must be bool, not SymBool`
+    if is_causal is None:
+        # The last condition is for encoder (decoder) models which specify this by passing their own `is_causal` flag
+        # This is mainly due to those models having mixed implementations for encoder, decoder, and encoder-decoder attns
+        is_causal = query.shape[2] > 1 and attention_mask is None and getattr(module, "is_causal", True)
+
+    if is_causal:
+        if attention_mask is not None:
+            raise ValueError("Causal mode cannot be used with an explicit `attention_mask`")
+        attention_mask = mint.ones((query.shape[-2], key.shape[-2]), dtype=ms.bool_).tril(diagonal=0)
+
+    if attention_mask is not None:
+        attention_mask = mint.logical_not(attention_mask)  # in MindSpore, 0 indicates retain, 1 indicates discard
 
-    attn_weights = mint.matmul(query, key_states.transpose(2, 3)) * scaling
-    if attention_mask is not None and attention_mask.dim() == 4:
-        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
-        attn_weights = attn_weights + causal_mask
+    scaling = 1 / sqrt(query.shape[-1]) if scaling is None else scaling
 
-    attn_weights = mint.nn.functional.softmax(attn_weights, dim=-1, dtype=ms.float32).to(query.dtype)
-    attn_weights = mint.nn.functional.dropout(attn_weights, p=dropout, training=module.training)
-    attn_output = mint.matmul(attn_weights, value_states)
-    attn_output = attn_output.transpose(1, 2).contiguous()
+    attn_output = ops.speed_fusion_attention(
+        query,
+        key,
+        value,
+        head_num=query.shape[1],
+        input_layout="BNSD",
+        atten_mask=attention_mask,
+        scale=scaling,
+        keep_prob=1 - dropout,
+    )[0]
+    attn_output = mint.transpose(attn_output, 1, 2).contiguous()
 
-    return attn_output, attn_weights
+    return attn_output, None
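
One detail worth calling out in the hunk above: `ops.speed_fusion_attention` interprets its `atten_mask` the opposite way from the boolean masks Transformers passes around, which is why the mask is flipped with `mint.logical_not` before the fused kernel is called. A minimal sketch of that convention, assuming a small causal mask:

import mindspore as ms
from mindspore import mint

seq_len = 4
# Transformers-style boolean mask: True (1) means "attend to this position"
keep_mask = mint.ones((seq_len, seq_len), dtype=ms.bool_).tril(diagonal=0)
# speed_fusion_attention convention: 1 means "discard this position", 0 means "retain"
discard_mask = mint.logical_not(keep_mask)

The upstream PyTorch path needs no such flip, since `scaled_dot_product_attention` treats True in a boolean mask as "take part in attention".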

tests/transformers_tests/integrations/__init__.py

Whitespace-only changes.
Lines changed: 162 additions & 0 deletions
@@ -0,0 +1,162 @@
+import numpy as np
+import pytest
+import torch
+from transformers.integrations.sdpa_attention import sdpa_attention_forward as sdpa_attention_forward_transformers
+
+from mindspore import grad, mint, set_seed, tensor
+
+from mindone.transformers.integrations.sdpa_attention import sdpa_attention_forward
+from tests.modeling_test_utils import MS_DTYPE_MAPPING, PT_DTYPE_MAPPING
+
+DTYPE_AND_THRESHOLDS = {"fp32": 1e-6, "fp16": 2e-3, "bf16": 2e-2}
+
+
+class MockAttentionModule:
+    def __init__(self, is_causal):
+        self.is_causal = is_causal
+        self.training = False
+
+
+@pytest.fixture(scope="module")
+def q_k_v_target_mask() -> dict[str, tuple]:
+    # B, H, S, D
+    set_seed(42)
+    q = np.random.uniform(size=(2, 8, 256, 32)).astype(np.float32)
+    k = np.random.uniform(size=(2, 8, 256, 32)).astype(np.float32)
+    v = np.random.uniform(size=(2, 8, 256, 32)).astype(np.float32)
+    target = np.random.uniform(size=(2, 256, 8, 32)).astype(np.float32)
+    attention_mask = np.random.randint(0, 2, (q.shape[0], 1, q.shape[2], k.shape[2]), dtype=bool)
+    return {
+        "ms": (tensor(q), tensor(k), tensor(v), tensor(target), tensor(attention_mask)),
+        "pt": (torch.tensor(q), torch.tensor(k), torch.tensor(v), torch.tensor(target), torch.tensor(attention_mask)),
+    }
+
+
+def cast_inputs(q, k, v, target, attention_mask, dtype):
+    return (
+        q.to(dtype),
+        k.to(dtype),
+        v.to(dtype),
+        target.to(dtype),
+        attention_mask,  # no casting needed for attention_mask
+    )
+
+
+@pytest.mark.parametrize("dtype", ["fp32", "fp16", "bf16"])
+@pytest.mark.parametrize("use_mask", [True, False], ids=["with_mask", "without_mask"])
+@pytest.mark.parametrize("jit", [False, True], ids=["eager", "jit"])
+def test_sdpa_attention_forward(q_k_v_target_mask, use_mask: bool, dtype: str, jit: bool):
+    if jit:
+        pytest.skip("`sdpa_attention_forward` can't be compiled with jit.")
+
+    module = MockAttentionModule(is_causal=False)
+
+    q, k, v, _, attn_mask = cast_inputs(*q_k_v_target_mask["ms"], dtype=MS_DTYPE_MAPPING[dtype])
+    q_pt, k_pt, v_pt, _, attn_mask_pt = cast_inputs(*q_k_v_target_mask["pt"], dtype=PT_DTYPE_MAPPING[dtype])
+
+    output = sdpa_attention_forward(module, q, k, v, attention_mask=attn_mask if use_mask else None, is_causal=False)[0]
+    output_pt = sdpa_attention_forward_transformers(
+        module, q_pt, k_pt, v_pt, attention_mask=attn_mask_pt if use_mask else None, is_causal=False
+    )[0]
+
+    assert output.shape == output_pt.shape, f"Shape mismatch: {output.shape} vs {output_pt.shape}"
+    assert not output.isnan().any(), "Output contains NaNs."
+    assert np.allclose(
+        output.numpy().astype(np.float32), output_pt.to(torch.float32).numpy(), atol=DTYPE_AND_THRESHOLDS[dtype]
+    )
+
+
+@pytest.mark.parametrize("dtype", ["fp32", "fp16", "bf16"])
+@pytest.mark.parametrize("use_mask", [True, False], ids=["with_mask", "without_mask"])
+@pytest.mark.parametrize("jit", [False, True], ids=["eager", "jit"])
+def test_sdpa_attention_backward(q_k_v_target_mask, use_mask: bool, dtype: str, jit: bool):
+    if jit:
+        pytest.skip("`sdpa_attention_forward` can't be compiled with jit.")
+
+    module = MockAttentionModule(is_causal=False)
+
+    # MindONE
+    q, k, v, target, attn_mask = cast_inputs(*q_k_v_target_mask["ms"], dtype=MS_DTYPE_MAPPING[dtype])
+
+    def _forward(q_, k_, v_, target_):
+        output = sdpa_attention_forward(
+            module, q_, k_, v_, attention_mask=attn_mask if use_mask else None, is_causal=False
+        )[0]
+        return mint.nn.functional.mse_loss(output, target_)
+
+    grad_out = grad(_forward, grad_position=(0, 1, 2))(q, k, v, target)
+    grad_out = mint.stack(grad_out, dim=0)
+
+    # Transformers
+    q_pt, k_pt, v_pt, target_pt, attn_mask_pt = cast_inputs(*q_k_v_target_mask["pt"], dtype=PT_DTYPE_MAPPING[dtype])
+    q_pt, k_pt, v_pt = q_pt.clone(), k_pt.clone(), v_pt.clone()
+    q_pt.requires_grad, k_pt.requires_grad, v_pt.requires_grad = (True,) * 3
+
+    output_pt = sdpa_attention_forward_transformers(
+        module, q_pt, k_pt, v_pt, attention_mask=attn_mask_pt if use_mask else None, is_causal=False
+    )[0]
+    loss = torch.nn.functional.mse_loss(output_pt, target_pt)
+    loss.backward()
+    grad_out_pt = torch.stack([q_pt.grad, k_pt.grad, v_pt.grad], dim=0)
+
+    assert grad_out.shape == grad_out_pt.shape, f"Shape mismatch: {grad_out.shape} vs {grad_out_pt.shape}"
+    assert not grad_out.isnan().any(), "Output contains NaNs."
+    assert np.allclose(
+        grad_out.numpy().astype(np.float32), grad_out_pt.to(torch.float32).numpy(), atol=DTYPE_AND_THRESHOLDS[dtype]
+    )
+
+
+@pytest.mark.parametrize("dtype", ["fp32", "fp16", "bf16"])
+@pytest.mark.parametrize("jit", [False, True], ids=["eager", "jit"])
+def test_sdpa_attention_causal_forward(q_k_v_target_mask, dtype: str, jit: bool):
+    if jit:
+        pytest.skip("`sdpa_attention_forward` can't be compiled with jit.")
+
+    module = MockAttentionModule(is_causal=True)
+
+    q, k, v, *_ = cast_inputs(*q_k_v_target_mask["ms"], dtype=MS_DTYPE_MAPPING[dtype])
+    q_pt, k_pt, v_pt, *_ = cast_inputs(*q_k_v_target_mask["pt"], dtype=PT_DTYPE_MAPPING[dtype])
+
+    output = sdpa_attention_forward(module, q, k, v, attention_mask=None)[0]
+    output_pt = sdpa_attention_forward_transformers(module, q_pt, k_pt, v_pt, attention_mask=None)[0]
+
+    assert output.shape == output_pt.shape, f"Shape mismatch: {output.shape} vs {output_pt.shape}"
+    assert not output.isnan().any(), "Output contains NaNs."
+    assert np.allclose(
+        output.numpy().astype(np.float32), output_pt.to(torch.float32).numpy(), atol=DTYPE_AND_THRESHOLDS[dtype]
+    )
+
+
+@pytest.mark.parametrize("dtype", ["fp32", "fp16", "bf16"])
+@pytest.mark.parametrize("jit", [False, True], ids=["eager", "jit"])
+def test_sdpa_attention_causal_backward(q_k_v_target_mask, dtype: str, jit: bool):
+    if jit:
+        pytest.skip("`sdpa_attention_forward` can't be compiled with jit.")
+
+    module = MockAttentionModule(is_causal=True)
+
+    # MindONE
+    q, k, v, target, _ = cast_inputs(*q_k_v_target_mask["ms"], dtype=MS_DTYPE_MAPPING[dtype])
+
+    def _forward(q_, k_, v_, target_):
+        output = sdpa_attention_forward(module, q_, k_, v_, attention_mask=None)[0]
+        return mint.nn.functional.mse_loss(output, target_)
+
+    grad_out = grad(_forward, grad_position=(0, 1, 2))(q, k, v, target)
+    grad_out = mint.stack(grad_out, dim=0)
+
+    # Transformers
+    q_pt, k_pt, v_pt, target_pt, _ = cast_inputs(*q_k_v_target_mask["pt"], dtype=PT_DTYPE_MAPPING[dtype])
+    q_pt, k_pt, v_pt = q_pt.clone(), k_pt.clone(), v_pt.clone()
+    q_pt.requires_grad, k_pt.requires_grad, v_pt.requires_grad = (True,) * 3
+
+    output_pt = sdpa_attention_forward_transformers(module, q_pt, k_pt, v_pt, attention_mask=None)[0]
+    loss = torch.nn.functional.mse_loss(output_pt, target_pt)
+    loss.backward()
+    grad_out_pt = torch.stack([q_pt.grad, k_pt.grad, v_pt.grad], dim=0)
+
+    assert grad_out.shape == grad_out_pt.shape, f"Shape mismatch: {grad_out.shape} vs {grad_out_pt.shape}"
+    assert not grad_out.isnan().any(), "Output contains NaNs."
+    assert np.allclose(
+        grad_out.numpy().astype(np.float32), grad_out_pt.to(torch.float32).numpy(), atol=DTYPE_AND_THRESHOLDS[dtype]
+    )
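
To run the new forward/backward parity tests locally, something along these lines should work (a sketch; the path is the test package added in this commit, and the exact test-file name follows the repository's layout under it):

import pytest

# collect and run every test module under the new integrations test package
exit_code = pytest.main(["-v", "tests/transformers_tests/integrations/"])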
