#include "infinicore/ops/attention.hpp"

#include "infinicore/ops/causal_softmax.hpp"
#include "infinicore/ops/gemm.hpp"

#include <cassert>
#include <cmath>
#include <optional>
35namespace infinicore ::op {
46
57common::OpDispatcher<Attention::schema> &Attention::dispatcher () {
@@ -25,4 +27,89 @@ void attention_(Tensor out, Tensor q, Tensor k, Tensor v, Tensor k_cache, Tensor
2527 Attention::execute (out, q, k, v, k_cache, v_cache, pos);
2628}
2729
30+ Tensor scaled_dot_product_attention (Tensor query_states, // [bs, num_attention_heads, ntoken, head_dim]
31+ Tensor key_states, // [bs, num_key_value_heads, total_token, head_dim]
32+ Tensor value_states, // [bs, num_key_value_heads, total_token, head_dim]
33+ std::optional<float > scale) {
34+
35+ auto query_shape = query_states->shape ();
36+ auto key_shape = key_states->shape ();
37+
38+ Size batch_size = query_shape[0 ];
39+ Size num_attention_heads = query_shape[1 ];
40+ Size ntoken = query_shape[2 ];
41+ Size head_dim = key_shape[3 ];
42+
43+ Tensor output_values = Tensor::empty ({batch_size, num_attention_heads, ntoken, head_dim}, query_states->dtype (), query_states->device ());
44+
45+ scaled_dot_product_attention_ (output_values, query_states, key_states, value_states, scale);
46+
47+ return output_values;
48+ }
49+
50+ void scaled_dot_product_attention_ (Tensor out,
51+ Tensor query_states,
52+ Tensor key_states,
53+ Tensor value_states,
54+ std::optional<float > scale) {
55+
56+ auto query_shape = query_states->shape ();
57+ auto key_shape = key_states->shape ();
58+
59+ Size batch_size = query_shape[0 ];
60+ Size num_attention_heads = query_shape[1 ];
61+ Size ntoken = query_shape[2 ];
62+
63+ Size num_key_value_heads = key_shape[1 ];
64+ Size total_token = key_shape[2 ];
65+ Size head_dim = key_shape[3 ];
66+
67+ assert (0 == (num_attention_heads % num_key_value_heads));
68+ Size ngroup = num_attention_heads / num_key_value_heads;
69+
70+ float attention_scale{0 .0f };
71+ if (scale.has_value ()) {
72+ attention_scale = scale.value ();
73+ } else {
74+ attention_scale = 1 .f / float (sqrt (head_dim));
75+ }
76+
77+ Tensor out_view = out->view ({batch_size, num_key_value_heads, ngroup * ntoken, head_dim});
78+ for (Size ib = 0 ; ib < batch_size; ++ib) {
79+ Tensor q = query_states->narrow ({{0 , ib, 1 }})->view ({num_attention_heads, ntoken, head_dim}); // [ num_attention_heads, ntoken, head_dim]
80+ Tensor k = key_states->narrow ({{0 , ib, 1 }})->view ({num_key_value_heads, total_token, head_dim}); // [ num_key_value_heads, total_token, head_dim]
81+ Tensor v = value_states->narrow ({{0 , ib, 1 }})->view ({num_key_value_heads, total_token, head_dim}); // [ num_key_value_heads, total_token, head_dim]
82+ Tensor output_v = out_view->narrow ({{0 , ib, 1 }})->view ({num_key_value_heads, ngroup * ntoken, head_dim});
83+ {
84+ /*
85+ 输入:
86+ q, [ num_attention_heads, ntoken, head_dim]
87+ k, [ num_key_value_heads, total_token, head_dim]
88+ v, [ num_key_value_heads, total_token, head_dim]
89+ 输出:
90+ att_val : {num_key_value_heads, ngroup * ntok, head_dim}
91+ */
92+
93+ auto q_gemm = q->view ({num_key_value_heads, ngroup * ntoken, head_dim}); // => {nkvh, ngroup * seq_len, dh}
94+ auto k_gemm = k->permute ({0 , 2 , 1 }); // => { nkvh, dh, total_token}
95+ auto v_gemm = v; // => { nkvh, total_token, dh}
96+
97+ // qk_score : => {nkvh, ngroup * ntoken, total_token}
98+ Tensor qk_score = gemm (q_gemm, // {nkvh, ngroup * ntoken, dh}
99+ k_gemm, // {nkvh, dh, total_token}
100+ attention_scale, 0 .f );
101+
102+ // softmax
103+
104+ auto qk_softmax = qk_score->view ({num_attention_heads, ntoken, total_token});
105+ causal_softmax_ (qk_softmax, qk_softmax);
106+
107+ // values
108+ gemm_ (output_v, // {nkvh, ngroup * ntoken, dh}
109+ qk_score, // {nkvh, ngroup * ntoken, total_token}
110+ v_gemm, // { nkvh, total_token, dh}
111+ 1 .0f , 0 .0f );
112+ }
113+ }
114+ }
28115} // namespace infinicore::op
0 commit comments