
Commit 271d287

Author: pengcheng888 (committed)

issue/586 - Add the Python implementation and tests for scaled_dot_product_attention
1 parent 2286cf7 commit 271d287

File tree

5 files changed: +296 -1 lines changed


include/infinicore/ops/attention.hpp

Lines changed: 12 additions & 1 deletion
@@ -2,7 +2,7 @@

 #include "../device.hpp"
 #include "common/op.hpp"
-
+#include <pybind11/pybind11.h>
 namespace infinicore::op {
 class Attention {
 public:
@@ -13,4 +13,15 @@ class Attention {

 Tensor attention(Tensor q, Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, size_t pos);
 void attention_(Tensor out, Tensor q, Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, size_t pos);
+
+Tensor scaled_dot_product_attention(Tensor query,
+                                    Tensor key,
+                                    Tensor value,
+                                    pybind11::object scale);
+
+void scaled_dot_product_attention_(Tensor out,
+                                   Tensor query,
+                                   Tensor key,
+                                   Tensor value,
+                                   pybind11::object scale);
 } // namespace infinicore::op
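
For reference, the operation these declarations expose is the standard (causally masked) scaled dot-product attention. A sketch of the math, with the default scale taken from the implementation in attention.cc below:

\[
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\mathrm{scale} \cdot Q K^{\top} + M\right) V,
\qquad
\mathrm{scale} = \frac{1}{\sqrt{\mathrm{head\_dim}}} \text{ when } \texttt{scale} \text{ is } \texttt{None},
\]

where M is the causal mask (0 on allowed positions, -inf on masked ones).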

python/infinicore/nn/functional.py

Lines changed: 43 additions & 0 deletions
@@ -1,3 +1,5 @@
+from typing import Optional
+
 import infinicore
 from infinicore.lib import _infinicore
 from infinicore.tensor import Tensor
@@ -65,3 +67,44 @@ def swiglu(input: Tensor, other: Tensor, *, out=None):
     _infinicore.swiglu_(out._underlying, input._underlying, other._underlying)

     return out
+
+
+def scaled_dot_product_attention(
+    query: Tensor,
+    key: Tensor,
+    value: Tensor,
+    attn_mask: Optional[Tensor] = None,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    scale: Optional[float] = None,
+    enable_gqa: bool = False,
+    *,
+    out=None,
+) -> Tensor:
+    r"""Computes scaled dot product attention on query, key and value tensors."""
+    assert (attn_mask is None) and (0.0 == dropout_p), "Unsupported parameters."
+    assert (enable_gqa is True) and (is_causal is True), "Incorrect parameter value."
+
+    ntoken = query.shape[-2]
+    total_token = key.shape[-2]
+
+    assert (1 == ntoken and total_token > 1) or (ntoken == total_token), (
+        "Incorrect parameter value."
+    )
+
+    if out is None:
+        return infinicore.Tensor(
+            _infinicore.scaled_dot_product_attention(
+                query._underlying, key._underlying, value._underlying, scale
+            )
+        )
+
+    _infinicore.scaled_dot_product_attention_(
+        out._underlying,
+        query._underlying,
+        key._underlying,
+        value._underlying,
+        scale,
+    )
+
+    return out
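
A minimal usage sketch of the new wrapper, assuming query, key and value are already infinicore Tensors with the layouts used in this commit (how such tensors are constructed is outside this diff); the helper name run_sdpa is hypothetical:

from typing import Optional

import infinicore.nn.functional as F


def run_sdpa(query, key, value, scale: Optional[float] = None, out=None):
    # query:     [bs, num_attention_heads, ntoken, head_dim]
    # key/value: [bs, num_key_value_heads, total_token, head_dim]
    # The wrapper only accepts the causal, grouped-query configuration,
    # so both flags must be passed explicitly as True.
    return F.scaled_dot_product_attention(
        query, key, value,
        is_causal=True, enable_gqa=True, scale=scale, out=out,
    )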

src/infinicore/ops/attention/attention.cc

Lines changed: 87 additions & 0 deletions
@@ -1,4 +1,6 @@
 #include "infinicore/ops/attention.hpp"
+#include "infinicore/ops/causal_softmax.hpp"
+#include "infinicore/ops/gemm.hpp"

 namespace infinicore::op {

@@ -25,4 +27,89 @@ void attention_(Tensor out, Tensor q, Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, size_t pos) {
     Attention::execute(out, q, k, v, k_cache, v_cache, pos);
 }

+Tensor scaled_dot_product_attention(Tensor query_states, // [bs, num_attention_heads, ntoken, head_dim]
+                                    Tensor key_states,   // [bs, num_key_value_heads, total_token, head_dim]
+                                    Tensor value_states, // [bs, num_key_value_heads, total_token, head_dim]
+                                    pybind11::object scale) {
+
+    auto query_shape = query_states->shape();
+    auto key_shape = key_states->shape();
+
+    Size batch_size = query_shape[0];
+    Size num_attention_heads = query_shape[1];
+    Size ntoken = query_shape[2];
+    Size head_dim = key_shape[3];
+
+    Tensor output_values = Tensor::empty({batch_size, num_attention_heads, ntoken, head_dim}, query_states->dtype(), query_states->device());
+
+    scaled_dot_product_attention_(output_values, query_states, key_states, value_states, scale);
+
+    return output_values;
+}
+
+void scaled_dot_product_attention_(Tensor out,
+                                   Tensor query_states,
+                                   Tensor key_states,
+                                   Tensor value_states,
+                                   pybind11::object scale) {
+
+    auto query_shape = query_states->shape();
+    auto key_shape = key_states->shape();
+
+    Size batch_size = query_shape[0];
+    Size num_attention_heads = query_shape[1];
+    Size ntoken = query_shape[2];
+
+    Size num_key_value_heads = key_shape[1];
+    Size total_token = key_shape[2];
+    Size head_dim = key_shape[3];
+
+    assert(0 == (num_attention_heads % num_key_value_heads));
+    Size ngroup = num_attention_heads / num_key_value_heads;
+
+    float attention_scale{0.0f};
+    if (!scale.is_none()) {
+        attention_scale = scale.cast<float>();
+    } else {
+        attention_scale = 1.f / float(sqrt(head_dim));
+    }
+
+    Tensor out_view = out->view({batch_size, num_key_value_heads, ngroup * ntoken, head_dim});
+    for (Size ib = 0; ib < batch_size; ++ib) {
+        Tensor q = query_states->narrow({{0, ib, 1}})->view({num_attention_heads, ntoken, head_dim});      // [num_attention_heads, ntoken, head_dim]
+        Tensor k = key_states->narrow({{0, ib, 1}})->view({num_key_value_heads, total_token, head_dim});   // [num_key_value_heads, total_token, head_dim]
+        Tensor v = value_states->narrow({{0, ib, 1}})->view({num_key_value_heads, total_token, head_dim}); // [num_key_value_heads, total_token, head_dim]
+        Tensor output_v = out_view->narrow({{0, ib, 1}})->view({num_key_value_heads, ngroup * ntoken, head_dim});
+        {
+            /*
+            Inputs:
+                q, [num_attention_heads, ntoken, head_dim]
+                k, [num_key_value_heads, total_token, head_dim]
+                v, [num_key_value_heads, total_token, head_dim]
+            Output:
+                att_val: {num_key_value_heads, ngroup * ntoken, head_dim}
+            */
+
+            auto q_gemm = q->view({num_key_value_heads, ngroup * ntoken, head_dim}); // => {nkvh, ngroup * ntoken, dh}
+            auto k_gemm = k->permute({0, 2, 1});                                     // => {nkvh, dh, total_token}
+            auto v_gemm = v;                                                         // => {nkvh, total_token, dh}
+
+            // qk_score: => {nkvh, ngroup * ntoken, total_token}
+            Tensor qk_score = gemm(q_gemm, // {nkvh, ngroup * ntoken, dh}
+                                   k_gemm, // {nkvh, dh, total_token}
+                                   attention_scale, 0.f);
+
+            // softmax
+            auto qk_softmax = qk_score->view({num_attention_heads, ntoken, total_token});
+            causal_softmax_(qk_softmax, qk_softmax);
+
+            // values
+            gemm_(output_v, // {nkvh, ngroup * ntoken, dh}
+                  qk_score, // {nkvh, ngroup * ntoken, total_token}
+                  v_gemm,   // {nkvh, total_token, dh}
+                  1.0f, 0.0f);
+        }
+    }
+}
 } // namespace infinicore::op
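
To make the gemm -> causal_softmax_ -> gemm_ pipeline above easier to verify, here is a minimal pure-PyTorch sketch of the same computation (not part of the commit). It assumes the causal mask is aligned to the bottom-right of each [ntoken, total_token] score block, which is what the single-token decode test case (ntoken == 1, total_token > 1) requires; sdpa_reference is a hypothetical helper name:

import math

import torch


def sdpa_reference(query, key, value, scale=None):
    # query:      [bs, num_attention_heads, ntoken, head_dim]
    # key/value:  [bs, num_key_value_heads, total_token, head_dim]
    bs, n_heads, ntoken, head_dim = query.shape
    n_kv_heads, total_token = key.shape[1], key.shape[2]
    ngroup = n_heads // n_kv_heads
    if scale is None:
        scale = 1.0 / math.sqrt(head_dim)

    # Fold each group of query heads onto its KV head (mirrors q->view(...)).
    q = query.reshape(bs, n_kv_heads, ngroup * ntoken, head_dim)
    scores = torch.matmul(q, key.transpose(-1, -2)) * scale  # [bs, nkvh, ngroup*ntoken, total_token]

    # Bottom-right-aligned causal mask applied per [ntoken, total_token] block
    # (mirrors causal_softmax_ on the {num_attention_heads, ntoken, total_token} view).
    row = torch.arange(ntoken).unsqueeze(-1)
    col = torch.arange(total_token)
    keep = col <= row + (total_token - ntoken)  # [ntoken, total_token] boolean mask
    scores = scores.reshape(bs, n_heads, ntoken, total_token)
    probs = torch.softmax(scores.masked_fill(~keep, float("-inf")), dim=-1)

    # Weighted sum of values (mirrors gemm_(output_v, qk_score, v_gemm, ...)).
    probs = probs.reshape(bs, n_kv_heads, ngroup * ntoken, total_token)
    out = torch.matmul(probs, value)  # [bs, nkvh, ngroup*ntoken, head_dim]
    return out.reshape(bs, n_heads, ntoken, head_dim)

For the prefill shapes in the tests (ntoken == total_token) this reduces to the usual lower-triangular causal mask; for single-token decode the single query row attends to every cached position.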

src/infinicore/pybind11/ops/attention.hpp

Lines changed: 17 additions & 0 deletions
@@ -51,6 +51,23 @@ inline void bind_attention(py::module &m) {
         v_cache: Value cache tensor
         pos: Current position in the sequence
     )doc");
+
+    m.def("scaled_dot_product_attention",
+          &op::scaled_dot_product_attention,
+          py::arg("query"),
+          py::arg("key"),
+          py::arg("value"),
+          py::arg("scale") = py::none(),
+          R"doc(Computes scaled dot product attention on query, key and value tensors)doc");
+
+    m.def("scaled_dot_product_attention_",
+          &op::scaled_dot_product_attention_,
+          py::arg("out"),
+          py::arg("query"),
+          py::arg("key"),
+          py::arg("value"),
+          py::arg("scale") = py::none(),
+          R"doc(In-place, Computes scaled dot product attention on query, key and value tensors)doc");
 }

 } // namespace infinicore::ops
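
Because py::arg("scale") = py::none() makes scale optional on the Python side, omitting it reaches the C++ fallback of 1/sqrt(head_dim). A small Python sketch of that resolution rule (resolve_scale is a hypothetical helper, not part of the bindings):

import math


def resolve_scale(scale, head_dim):
    # Mirrors the fallback in scaled_dot_product_attention_: an explicit
    # scale wins; otherwise 1/sqrt(head_dim) is used.
    return float(scale) if scale is not None else 1.0 / math.sqrt(head_dim)


assert resolve_scale(None, 64) == 0.125  # head_dim = 64 -> 1/8
assert resolve_scale(0.5, 64) == 0.5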
Lines changed: 137 additions & 0 deletions
@@ -0,0 +1,137 @@
+import sys
+import os
+
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
+
+import torch
+import infinicore
+from framework.base import BaseOperatorTest, TensorSpec, TestCase
+from framework.runner import GenericTestRunner
+from framework.utils import is_broadcast
+
+
+# ==============================================================================
+# Operator-specific configuration
+# ==============================================================================
+_TEST_CASES_DATA = [
+    # bs, ntoken, total_token, num_attention_heads, num_key_value_heads, head_dim
+    (1, 5, 5, 32, 8, 64),
+    (1, 1, 5, 32, 8, 64),
+    (4, 16, 16, 32, 8, 64),
+    (4, 1, 128, 32, 8, 64),
+]
+
+
+# Tolerance configuration
+_TOLERANCE_MAP = {
+    infinicore.float16: {"atol": 1e-2, "rtol": 1e-2},
+    infinicore.float32: {"atol": 1e-2, "rtol": 1e-2},
+    infinicore.bfloat16: {"atol": 5e-2, "rtol": 5e-2},
+}
+
+
+# Data types to test
+_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
+
+
+def parse_test_cases():
+    """
+    Parse test case data and return a list of TestCase objects for the sdpa operation.
+    Each test case contains all necessary information for execution and validation.
+    """
+    test_cases = []
+
+    for data in _TEST_CASES_DATA:
+        bs = data[0]
+        ntoken, total_token = data[1], data[2]
+        num_attention_heads, num_key_value_heads = data[3], data[4]
+        head_dim = data[5]
+
+        # Determine shapes based on batch dimension
+        query_shape = (bs, num_attention_heads, ntoken, head_dim)
+        key_shape = (bs, num_key_value_heads, total_token, head_dim)
+        value_shape = (bs, num_key_value_heads, total_token, head_dim)
+        out_shape = (bs, num_attention_heads, ntoken, head_dim)
+
+        # Check if tensors support in-place operations
+        c_supports_inplace = not is_broadcast(out_shape)
+
+        # Generate test cases for all data types
+        for dtype in _TENSOR_DTYPES:
+            tolerance = _TOLERANCE_MAP.get(dtype, {"atol": 0, "rtol": 1e-3})
+
+            # Create typed tensor specs
+            query_spec = TensorSpec.from_tensor(query_shape, None, dtype)
+            key_spec = TensorSpec.from_tensor(key_shape, None, dtype)
+            value_spec = TensorSpec.from_tensor(value_shape, None, dtype)
+            out_spec = TensorSpec.from_tensor(out_shape, None, dtype)
+
+            # Test Case 1: Out-of-place (return value)
+            test_cases.append(
+                TestCase(
+                    inputs=[query_spec, key_spec, value_spec],
+                    kwargs={},
+                    output_spec=None,
+                    comparison_target=None,
+                    tolerance=tolerance,
+                    description=f"sdpa - OUT_OF_PLACE",
+                )
+            )
+
+            # Test Case 2: In-place with explicit output tensor
+            if c_supports_inplace:
+                test_cases.append(
+                    TestCase(
+                        inputs=[query_spec, key_spec, value_spec],
+                        kwargs=None,
+                        output_spec=out_spec,  # Specify the output tensor spec
+                        comparison_target="out",
+                        tolerance=tolerance,
+                        description=f"sdpa - INPLACE(out)",
+                    )
+                )
+
+    return test_cases
+
+
+class OpTest(BaseOperatorTest):
+    """sdpa operator test with simplified implementation"""
+
+    def __init__(self):
+        super().__init__("sdpa")
+
+    def get_test_cases(self):
+        return parse_test_cases()
+
+    def torch_operator(self, query, key, value, out=None, **kwargs):
+        """PyTorch sdpa implementation"""
+        ntoken = query.shape[-2]
+        total_token = key.shape[-2]
+
+        is_causal = True
+        if 1 == ntoken and total_token > 1:
+            is_causal = False
+
+        result = torch.nn.functional.scaled_dot_product_attention(
+            query, key, value, is_causal=is_causal, enable_gqa=True
+        )
+        if out is not None:
+            out.copy_(result)
+            return out
+        return result
+
+    def infinicore_operator(self, query, key, value, out=None, **kwargs):
+        """InfiniCore sdpa implementation"""
+        return infinicore.nn.functional.scaled_dot_product_attention(
+            query, key, value, is_causal=True, enable_gqa=True, out=out
+        )
+
+
+def main():
+    """Main entry point"""
+    runner = GenericTestRunner(OpTest)
+    runner.run_and_exit()
+
+
+if __name__ == "__main__":
+    main()
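
The reference path in torch_operator relies on enable_gqa=True (available in recent PyTorch releases) to broadcast the 8 KV heads across the 32 query heads. A short sketch of the equivalence it assumes, using explicit head repetition; not part of the commit:

import torch

torch.manual_seed(0)
bs, n_heads, n_kv_heads, ntoken, head_dim = 1, 32, 8, 5, 64
q = torch.randn(bs, n_heads, ntoken, head_dim)
k = torch.randn(bs, n_kv_heads, ntoken, head_dim)
v = torch.randn(bs, n_kv_heads, ntoken, head_dim)

# Grouped-query attention: SDPA expands each KV head across its group of query heads.
gqa = torch.nn.functional.scaled_dot_product_attention(
    q, k, v, is_causal=True, enable_gqa=True
)

# The same result via explicit repetition of the KV heads.
ngroup = n_heads // n_kv_heads
k_rep = k.repeat_interleave(ngroup, dim=1)
v_rep = v.repeat_interleave(ngroup, dim=1)
manual = torch.nn.functional.scaled_dot_product_attention(q, k_rep, v_rep, is_causal=True)

torch.testing.assert_close(gqa, manual)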
