/* Copyright 2025 The xLLM Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://github.com/jd-opensource/xllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mlu_ops_api.h"
#include "torch_mlu_ops.h"

19+ namespace xllm ::mlu {
20+
21+ void reshape_paged_cache (const torch::Tensor& key,
22+ const torch::Tensor& value,
23+ torch::Tensor& k_cache,
24+ torch::Tensor& v_cache,
25+ const torch::Tensor& slot_mapping,
26+ bool direction) {
27+ tmo::torch_api::reshape_paged_cache (
28+ key, value, k_cache, v_cache, slot_mapping, direction);
29+ }
30+
31+ void flash_attention (const torch::Tensor& query,
32+ const torch::Tensor& key,
33+ const torch::Tensor& value,
34+ torch::Tensor& output,
35+ torch::Tensor& output_lse,
36+ int query_start_loc,
37+ int seq_start_loc,
38+ const std::optional<torch::Tensor>& alibi_slope,
39+ const std::optional<torch::Tensor>& attn_bias,
40+ const std::optional<torch::Tensor>& q_quant_scale,
41+ const std::optional<torch::Tensor>& k_quant_scale,
42+ const std::optional<torch::Tensor>& v_quant_scale,
43+ const std::optional<torch::Tensor>& out_quant_scale,
44+ const std::optional<torch::Tensor>& block_tables,
45+ int max_query_len,
46+ int max_seq_len,
47+ float scale,
48+ bool is_causal,
49+ int window_size_left,
50+ int window_size_right,
51+ const std::string& compute_dtype,
52+ bool return_lse) {
53+ tmo::torch_api::flash_attention (query,
54+ key,
55+ value,
56+ output,
57+ output_lse,
58+ query_start_loc,
59+ seq_start_loc,
60+ alibi_slope,
61+ attn_bias,
62+ q_quant_scale,
63+ k_quant_scale,
64+ v_quant_scale,
65+ out_quant_scale,
66+ block_tables,
67+ max_query_len,
68+ max_seq_len,
69+ scale,
70+ is_causal,
71+ window_size_left,
72+ window_size_right,
73+ compute_dtype,
74+ return_lse);
75+ }
76+
77+ void single_query_cached_kv_attn (
78+ const torch::Tensor& query,
79+ const torch::Tensor& k_cache,
80+ torch::Tensor& output,
81+ const torch::Tensor& block_table,
82+ const torch::Tensor& seq_lens,
83+ const torch::Tensor& v_cache,
84+ torch::Tensor& output_lse,
85+ const std::optional<torch::Tensor>& q_quant_scale,
86+ const std::optional<torch::Tensor>& k_cache_quant_scale,
87+ const std::optional<torch::Tensor>& v_cache_quant_scale,
88+ const std::optional<torch::Tensor>& out_quant_scale,
89+ const std::optional<torch::Tensor>& alibi_slope,
90+ const std::optional<torch::Tensor>& mask,
91+ const std::string& compute_dtype,
92+ int max_seq_len,
93+ int window_size_left,
94+ int window_size_right,
95+ float scale,
96+ bool return_lse,
97+ int kv_cache_quant_bit_size) {
98+ tmo::torch_api::single_query_cached_kv_attn (query,
99+ k_cache,
100+ output,
101+ block_table,
102+ seq_lens,
103+ v_cache,
104+ output_lse,
105+ q_quant_scale,
106+ k_cache_quant_scale,
107+ v_cache_quant_scale,
108+ out_quant_scale,
109+ alibi_slope,
110+ mask,
111+ compute_dtype,
112+ max_seq_len,
113+ window_size_left,
114+ window_size_right,
115+ scale,
116+ return_lse,
117+ kv_cache_quant_bit_size);
118+ }
119+
120+ } // namespace xllm::mlu