Commit 0609cd9
Your Name committed:

issue/586 - Add the Python implementation of scaled_dot_product_attention and tests
1 parent 2286cf7 commit 0609cd9

File tree

8 files changed: +335 −71 lines

include/infinicore/ops/attention.hpp

Lines changed: 12 additions & 0 deletions
@@ -2,6 +2,7 @@
 
 #include "../device.hpp"
 #include "common/op.hpp"
+#include <optional>
 
 namespace infinicore::op {
 class Attention {
@@ -13,4 +14,15 @@ class Attention {
 
 Tensor attention(Tensor q, Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, size_t pos);
 void attention_(Tensor out, Tensor q, Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, size_t pos);
+
+Tensor scaled_dot_product_attention(Tensor query,
+                                    Tensor key,
+                                    Tensor value,
+                                    std::optional<float> scale);
+
+void scaled_dot_product_attention_(Tensor out,
+                                   Tensor query,
+                                   Tensor key,
+                                   Tensor value,
+                                   std::optional<float> scale);
 } // namespace infinicore::op
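
For reference, the scale parameter is optional; as the implementation in attention.cc below shows, an empty std::optional falls back to the conventional factor, with causality enforced by causal_softmax_:

    \mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(s \, Q K^{\top}\right) V,
    \qquad
    s = \begin{cases} \texttt{scale} & \text{if provided} \\ 1 / \sqrt{d_{\mathrm{head}}} & \text{otherwise} \end{cases}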

python/infinicore/nn/__init__.py

Lines changed: 3 additions & 3 deletions
@@ -1,3 +1,3 @@
-from infinicore.nn import (
-    functional as functional,
-)
+from infinicore.nn import functional
+
+__all__ = ["functional"]

python/infinicore/nn/functional.py

Lines changed: 0 additions & 67 deletions
This file was deleted.
Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+from .scaled_dot_product_attention import scaled_dot_product_attention
+
+__all__ = ["scaled_dot_product_attention"]
Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
+from typing import Optional
+
+from infinicore.lib import _infinicore
+from infinicore.tensor import Tensor
+
+__all__ = ["scaled_dot_product_attention"]
+
+
+def scaled_dot_product_attention(
+    query: Tensor,
+    key: Tensor,
+    value: Tensor,
+    attn_mask: Optional[Tensor] = None,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    scale: Optional[float] = None,
+    enable_gqa: bool = False,
+    *,
+    out=None,
+) -> Tensor:
+    r"""Computes scaled dot product attention on query, key and value tensors."""
+
+    assert (attn_mask is None) and (0.0 == dropout_p), "Unsupported parameters."
+    assert (enable_gqa is True) and (is_causal is True), "Incorrect parameter value."
+
+    ntoken = query.shape[-2]
+    total_token = key.shape[-2]
+
+    assert (1 == ntoken and total_token > 1) or (ntoken == total_token), (
+        "Incorrect parameter value."
+    )
+
+    if out is None:
+        return Tensor(
+            _infinicore.scaled_dot_product_attention(
+                query._underlying, key._underlying, value._underlying, scale
+            )
+        )
+
+    _infinicore.scaled_dot_product_attention_(
+        out._underlying,
+        query._underlying,
+        key._underlying,
+        value._underlying,
+        scale,
+    )
+
+    return out
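
To pin down the semantics the wrapper forwards to C++, here is a minimal NumPy reference sketch of causal grouped-query attention for a single batch entry; it is illustrative only and independent of infinicore's tensor types. Note that the asserts above mean callers must pass enable_gqa=True and is_causal=True explicitly, since both default to False.

    import numpy as np

    def sdpa_reference(q, k, v, scale=None):
        """Causal grouped-query attention on [heads, tokens, head_dim] arrays."""
        nh, ntoken, dh = q.shape
        nkvh, total_token, _ = k.shape
        ngroup = nh // nkvh                      # query heads per KV head
        s = scale if scale is not None else 1.0 / np.sqrt(dh)

        k = np.repeat(k, ngroup, axis=0)         # broadcast KV heads to query heads
        v = np.repeat(v, ngroup, axis=0)

        scores = s * (q @ k.transpose(0, 2, 1))  # [nh, ntoken, total_token]
        # Right-aligned causal mask: query i attends to keys j <= offset + i,
        # which covers both prefill (ntoken == total_token) and decode (ntoken == 1).
        offset = total_token - ntoken
        i = np.arange(ntoken)[:, None]
        j = np.arange(total_token)[None, :]
        scores = np.where(j <= offset + i, scores, -np.inf)

        w = np.exp(scores - scores.max(axis=-1, keepdims=True))
        w /= w.sum(axis=-1, keepdims=True)       # row-wise softmax
        return w @ v                             # [nh, ntoken, dh]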

src/infinicore/ops/attention/attention.cc

Lines changed: 88 additions & 1 deletion
@@ -1,5 +1,7 @@
 #include "infinicore/ops/attention.hpp"
-
+#include "infinicore/ops/causal_softmax.hpp"
+#include "infinicore/ops/gemm.hpp"
+#include <cmath>
 namespace infinicore::op {
 
 common::OpDispatcher<Attention::schema> &Attention::dispatcher() {
@@ -25,4 +27,89 @@ void attention_(Tensor out, Tensor q, Tensor k, Tensor v, Tensor k_cache, Tensor
     Attention::execute(out, q, k, v, k_cache, v_cache, pos);
 }
 
+Tensor scaled_dot_product_attention(Tensor query_states, // [bs, num_attention_heads, ntoken, head_dim]
+                                    Tensor key_states,   // [bs, num_key_value_heads, total_token, head_dim]
+                                    Tensor value_states, // [bs, num_key_value_heads, total_token, head_dim]
+                                    std::optional<float> scale) {
+
+    auto query_shape = query_states->shape();
+    auto key_shape = key_states->shape();
+
+    Size batch_size = query_shape[0];
+    Size num_attention_heads = query_shape[1];
+    Size ntoken = query_shape[2];
+    Size head_dim = key_shape[3];
+
+    Tensor output_values = Tensor::empty({batch_size, num_attention_heads, ntoken, head_dim}, query_states->dtype(), query_states->device());
+
+    scaled_dot_product_attention_(output_values, query_states, key_states, value_states, scale);
+
+    return output_values;
+}
+
+void scaled_dot_product_attention_(Tensor out,
+                                   Tensor query_states,
+                                   Tensor key_states,
+                                   Tensor value_states,
+                                   std::optional<float> scale) {
+
+    auto query_shape = query_states->shape();
+    auto key_shape = key_states->shape();
+
+    Size batch_size = query_shape[0];
+    Size num_attention_heads = query_shape[1];
+    Size ntoken = query_shape[2];
+
+    Size num_key_value_heads = key_shape[1];
+    Size total_token = key_shape[2];
+    Size head_dim = key_shape[3];
+
+    assert(0 == (num_attention_heads % num_key_value_heads));
+    Size ngroup = num_attention_heads / num_key_value_heads;
+
+    float attention_scale{0.0f};
+    if (scale.has_value()) {
+        attention_scale = scale.value();
+    } else {
+        attention_scale = 1.f / float(sqrt(head_dim));
+    }
+
+    Tensor out_view = out->view({batch_size, num_key_value_heads, ngroup * ntoken, head_dim});
+    for (Size ib = 0; ib < batch_size; ++ib) {
+        Tensor q = query_states->narrow({{0, ib, 1}})->view({num_attention_heads, ntoken, head_dim});      // [num_attention_heads, ntoken, head_dim]
+        Tensor k = key_states->narrow({{0, ib, 1}})->view({num_key_value_heads, total_token, head_dim});   // [num_key_value_heads, total_token, head_dim]
+        Tensor v = value_states->narrow({{0, ib, 1}})->view({num_key_value_heads, total_token, head_dim}); // [num_key_value_heads, total_token, head_dim]
+        Tensor output_v = out_view->narrow({{0, ib, 1}})->view({num_key_value_heads, ngroup * ntoken, head_dim});
+        {
+            /*
+            Inputs:
+                q, [num_attention_heads, ntoken, head_dim]
+                k, [num_key_value_heads, total_token, head_dim]
+                v, [num_key_value_heads, total_token, head_dim]
+            Output:
+                att_val: {num_key_value_heads, ngroup * ntoken, head_dim}
+            */
+
+            auto q_gemm = q->view({num_key_value_heads, ngroup * ntoken, head_dim}); // => {nkvh, ngroup * ntoken, dh}
+            auto k_gemm = k->permute({0, 2, 1});                                     // => {nkvh, dh, total_token}
+            auto v_gemm = v;                                                         // => {nkvh, total_token, dh}
+
+            // qk_score: {nkvh, ngroup * ntoken, total_token}
+            Tensor qk_score = gemm(q_gemm, // {nkvh, ngroup * ntoken, dh}
+                                   k_gemm, // {nkvh, dh, total_token}
+                                   attention_scale, 0.f);
+
+            // softmax
+            auto qk_softmax = qk_score->view({num_attention_heads, ntoken, total_token});
+            causal_softmax_(qk_softmax, qk_softmax);
+
+            // values
+            gemm_(output_v, // {nkvh, ngroup * ntoken, dh}
+                  qk_score, // {nkvh, ngroup * ntoken, total_token}
+                  v_gemm,   // {nkvh, total_token, dh}
+                  1.0f, 0.0f);
+        }
+    }
+}
 } // namespace infinicore::op
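
The reshape at the heart of the kernel — viewing q as {num_key_value_heads, ngroup * ntoken, head_dim} — stacks the ngroup query heads that share one KV head into a single batched GEMM operand, so K and V never need to be physically repeated; the later view back to {num_attention_heads, ntoken, total_token} for causal_softmax_ is the inverse flattening. A minimal NumPy check of why the grouped GEMM is equivalent (shapes illustrative, not infinicore code):

    import numpy as np

    nkvh, ngroup, ntoken, dh, total_token = 2, 3, 4, 8, 4
    nh = nkvh * ngroup
    rng = np.random.default_rng(0)
    q = rng.standard_normal((nh, ntoken, dh))
    k = rng.standard_normal((nkvh, total_token, dh))

    # Grouped GEMM: stack the ngroup query heads belonging to each KV head.
    q_gemm = q.reshape(nkvh, ngroup * ntoken, dh)
    score_grouped = q_gemm @ k.transpose(0, 2, 1)  # {nkvh, ngroup*ntoken, total_token}

    # Naive per-head GEMM with K repeated across each group of query heads.
    score_naive = q @ np.repeat(k, ngroup, axis=0).transpose(0, 2, 1)

    assert np.allclose(score_grouped.reshape(nh, ntoken, total_token), score_naive)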

src/infinicore/pybind11/ops/attention.hpp

Lines changed: 43 additions & 0 deletions
@@ -6,6 +6,32 @@
 
 namespace py = pybind11;
 
+namespace infinicore::ops::attention {
+Tensor py_scaled_dot_product_attention(Tensor query,
+                                       Tensor key,
+                                       Tensor value,
+                                       pybind11::object scale) {
+    std::optional<float> scale_float = std::nullopt;
+    if (!scale.is_none()) {
+        scale_float = scale.cast<float>();
+    }
+    return op::scaled_dot_product_attention(query, key, value, scale_float);
+}
+
+void py_scaled_dot_product_attention_(Tensor out,
+                                      Tensor query,
+                                      Tensor key,
+                                      Tensor value,
+                                      pybind11::object scale) {
+    std::optional<float> scale_float = std::nullopt;
+    if (!scale.is_none()) {
+        scale_float = scale.cast<float>();
+    }
+    op::scaled_dot_product_attention_(out, query, key, value, scale_float);
+}
+
+} // namespace infinicore::ops::attention
+
 namespace infinicore::ops {
 
 inline void bind_attention(py::module &m) {
@@ -51,6 +77,23 @@ inline void bind_attention(py::module &m) {
           v_cache: Value cache tensor
           pos: Current position in the sequence
     )doc");
+
+    m.def("scaled_dot_product_attention",
+          &attention::py_scaled_dot_product_attention,
+          py::arg("query"),
+          py::arg("key"),
+          py::arg("value"),
+          py::arg("scale") = py::none(),
+          R"doc(Computes scaled dot product attention on query, key and value tensors.)doc");
+
+    m.def("scaled_dot_product_attention_",
+          &attention::py_scaled_dot_product_attention_,
+          py::arg("out"),
+          py::arg("query"),
+          py::arg("key"),
+          py::arg("value"),
+          py::arg("scale") = py::none(),
+          R"doc(In-place variant: computes scaled dot product attention on query, key and value tensors.)doc");
 }
 
 } // namespace infinicore::ops
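
A small runnable sketch of the scale-handling contract these bindings implement: Python's None arrives as py::none(), is converted to std::nullopt, and the kernel then falls back to 1/sqrt(head_dim). The helper name resolve_scale is illustrative, not part of infinicore.

    import math
    from typing import Optional

    def resolve_scale(scale: Optional[float], head_dim: int) -> float:
        # Mirrors the C++ fallback: missing scale -> 1/sqrt(head_dim).
        return scale if scale is not None else 1.0 / math.sqrt(head_dim)

    assert resolve_scale(None, 64) == 1.0 / math.sqrt(64)  # default path
    assert resolve_scale(0.5, 64) == 0.5                   # explicit override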
