
Commit c814ead

guanbaoy and guangyey authored
[Fix] fix sdpa_math with scale input (#4993) (#5002)
* fix sdpa_math with scale input
* add ut
* fix flake8

Co-authored-by: guangyey <[email protected]>
1 parent 584f689 commit c814ead

2 files changed (+26, -1 lines)


csrc/gpu/aten/operators/transformers/sdp_utils.h

Lines changed: 1 addition & 1 deletion
@@ -61,7 +61,7 @@ inline c10::SymFloat calculate_scale(
     const at::Tensor& query,
     c10::optional<double> scale) {
   const auto softmax_scale = scale.has_value()
-      ? scale.value()
+      ? (c10::SymFloat(1.0) / scale.value())
       : c10::SymFloat(query.sym_size(-1)).sqrt();
   return c10::SymFloat(softmax_scale);
 }
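For context, F.scaled_dot_product_attention treats its `scale` argument as the multiplier applied to the Q.K^T logits (defaulting to 1/sqrt(head_dim)), while the helper above returns the value used on the other side of that relationship, so a user-supplied scale has to be inverted rather than returned directly. The snippet below is an illustrative Python sketch of that intent, not code from this commit, and it assumes the returned value is consumed as a divisor on the logits downstream:

import math

def calculate_scale_sketch(head_dim, scale=None):
    # Hypothetical mirror of the C++ helper above: a user-supplied `scale` is a
    # multiplier on the attention logits, so its reciprocal is the divisor;
    # without it, fall back to the conventional sqrt(head_dim) divisor.
    return 1.0 / scale if scale is not None else math.sqrt(head_dim)

# With the default, dividing by sqrt(head_dim) equals multiplying by head_dim**-0.5.
assert abs(1.0 / calculate_scale_sketch(256) - 256**-0.5) < 1e-12
# With an explicit scale s, dividing by the returned value multiplies by s again.
assert abs(1.0 / calculate_scale_sketch(256, scale=0.125) - 0.125) < 1e-12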

tests/gpu/examples/test_sdp.py

Lines changed: 25 additions & 0 deletions
@@ -43,3 +43,28 @@ def test_sdp_math_half(self, dtype=torch.float16):
         out_xpu = F.scaled_dot_product_attention(query.xpu(), key.xpu(), value.xpu())

         self.assertEqual(out_cpu, out_xpu.cpu().float(), atol=1e-3, rtol=1e-3)
+
+    def test_sdp_math_fp32(self, dtype=torch.float):
+        head_dim = 256
+        seq_lenth = 1
+        k_seq_lenth = 33
+        v_seq_lenth = 33
+        scale = head_dim**-0.5
+        query = torch.rand(1, 16, seq_lenth, head_dim, dtype=dtype)
+        key = torch.rand(1, 16, k_seq_lenth, head_dim, dtype=dtype)
+        value = torch.rand(1, 16, v_seq_lenth, head_dim, dtype=dtype)
+
+        out_cpu = F.scaled_dot_product_attention(
+            query.float(),
+            key.float(),
+            value.float(),
+            scale=scale,
+        )
+        out_xpu = F.scaled_dot_product_attention(
+            query.xpu(),
+            key.xpu(),
+            value.xpu(),
+            scale=scale,
+        )
+
+        self.assertEqual(out_cpu, out_xpu.cpu().float(), atol=1e-3, rtol=1e-3)
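As a usage note, the new test passes an explicit scale of head_dim**-0.5 and compares the XPU result against the CPU reference. The same behavior can be cross-checked against a hand-rolled reference; the sketch below is illustrative and not part of the commit, and it assumes a PyTorch build where scaled_dot_product_attention accepts the `scale` keyword:

import torch
import torch.nn.functional as F

def sdpa_reference(query, key, value, scale):
    # `scale` multiplies the Q @ K^T logits before the softmax, matching the
    # documented semantics of F.scaled_dot_product_attention.
    attn = torch.softmax(query @ key.transpose(-2, -1) * scale, dim=-1)
    return attn @ value

head_dim = 256
query = torch.rand(1, 16, 1, head_dim)
key = torch.rand(1, 16, 33, head_dim)
value = torch.rand(1, 16, 33, head_dim)
scale = head_dim**-0.5

ref = sdpa_reference(query, key, value, scale)
out = F.scaled_dot_product_attention(query, key, value, scale=scale)
print(torch.allclose(ref, out, atol=1e-5))  # expected: True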
