
Commit b255e9a

feat: refactor kernel dir and add flashinfer for cuda kernel.
1 parent: eda1805

Note: large commits have some content hidden by default, so not every changed file appears in full below.

43 files changed: +1211 −330 lines

.gitmodules

Lines changed: 6 additions & 0 deletions
@@ -28,3 +28,9 @@
 [submodule "third_party/Mooncake"]
     path = third_party/Mooncake
     url = https://github.com/kvcache-ai/Mooncake.git
+[submodule "third_party/flashinfer"]
+    path = third_party/flashinfer
+    url = https://gitcode.com/xLLM-AI/flashinfer.git
+[submodule "third_party/cutlass"]
+    path = third_party/cutlass
+    url = https://gitcode.com/xLLM-AI/cutlass.git
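
With these entries in place, a fresh checkout pulls both dependencies via the standard git submodule update --init --recursive. Note that both new submodules point at gitcode.com mirrors under the xLLM-AI organization rather than the upstream GitHub repositories.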

third_party/CMakeLists.txt

Lines changed: 21 additions & 0 deletions
@@ -20,3 +20,24 @@ target_include_directories(mooncake_store PUBLIC
 )
 
 target_link_libraries(mooncake_store PUBLIC transfer_engine cachelib_memory_allocator)
+
+
+if(USE_CUDA)
+  cc_library(
+    NAME
+      cutlass
+    INCLUDES
+      cutlass/include
+      cutlass/tools/util/include
+    DEPS
+      torch # TODO: depends on CUDA instead of torch
+  )
+  cc_library(
+    NAME
+      flashinfer
+    INCLUDES
+      flashinfer/include
+    DEPS
+      cutlass
+  )
+endif()
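
Neither target lists SRCS, so both are effectively header-only interface targets: cutlass exposes its include/ and tools/util/include directories, and flashinfer exposes flashinfer/include and inherits the CUTLASS headers through DEPS. The TODO marks the torch dependency as a stand-in for a direct CUDA dependency. cc_library is the project's own wrapper macro, so the exact target properties depend on its definition.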

third_party/cutlass

Submodule cutlass added at e6e2cc2

third_party/flashinfer

Submodule flashinfer added at bd98dac

xllm/core/kernels/CMakeLists.txt

Lines changed: 19 additions & 3 deletions
@@ -1,12 +1,28 @@
 include(cc_library)
 
 if(USE_NPU)
-  include_directories(
-    ${CMAKE_SOURCE_DIR}/third_party/spdlog/include
-  )
   add_subdirectory(npu)
 endif()
 
 if(USE_MLU)
   add_subdirectory(mlu)
 endif()
+
+if(USE_CUDA)
+  add_subdirectory(cuda)
+endif()
+
+cc_library(
+  NAME
+    kernels
+  HDRS
+    param.h
+    torch_ops_api.h
+  SRCS
+    torch_ops_api.cpp
+  DEPS
+    torch
+    $<$<BOOL:${USE_NPU}>:npu_kernels>
+    $<$<BOOL:${USE_MLU}>:mlu_kernels>
+    $<$<BOOL:${USE_CUDA}>:cuda_kernels>
+)
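
The $<$<BOOL:${USE_NPU}>:npu_kernels> entries are CMake generator expressions: the inner $<BOOL:...> evaluates the option to 0 or 1, and the outer conditional expands to the backend library name only when it is 1. The unified kernels target therefore links exactly the backend libraries enabled at configure time, with no if/else blocks in the DEPS list.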

xllm/core/kernels/mlu/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
@@ -2,7 +2,6 @@ include(cc_library)
 
 file(GLOB_RECURSE MLU_HEADER_FILES
   "${CMAKE_CURRENT_LIST_DIR}/*.h"
-  "${CMAKE_CURRENT_LIST_DIR}/*.hpp"
 )
 
 file(GLOB_RECURSE MLU_SOURCE_FILES

xllm/core/kernels/mlu/active.cpp

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mlu_ops_api.h"
+#include "torch_mlu_ops.h"
+
+namespace xllm::mlu {
+
+void active(const torch::Tensor& input,
+            torch::Tensor& output,
+            const std::optional<torch::Tensor>& bias,
+            const std::optional<torch::Tensor>& cusum_token_count,
+            const std::string& act_mode,
+            bool is_gated,
+            int start_expert_id,
+            int expert_size) {
+  tmo::torch_api::active(input,
+                         output,
+                         bias,
+                         cusum_token_count,
+                         act_mode,
+                         is_gated,
+                         start_expert_id,
+                         expert_size);
+}
+}  // namespace xllm::mlu
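
The wrapper is a straight pass-through to tmo::torch_api::active. For orientation, here is a minimal call-site sketch, assuming an MLU build where torch tensors can be placed on an "mlu" device; the "silu" mode string, the gated [tokens, 2*hidden] input layout, and the zeroed expert arguments are illustrative assumptions, not something this commit documents.

#include <optional>
#include <torch/torch.h>
#include "mlu_ops_api.h"

void active_example() {
  // Device string "mlu" assumes the MLU torch backend is registered.
  auto opts = torch::dtype(torch::kHalf).device("mlu");
  // Gated activations consume [tokens, 2 * hidden] and write [tokens, hidden]
  // (assumed layout for a SiLU-and-mul style MLP).
  auto input = torch::randn({16, 2 * 4096}, opts);
  auto output = torch::empty({16, 4096}, opts);
  xllm::mlu::active(input,
                    output,
                    /*bias=*/std::nullopt,
                    /*cusum_token_count=*/std::nullopt,
                    /*act_mode=*/"silu",    // assumed mode string
                    /*is_gated=*/true,
                    /*start_expert_id=*/0,  // non-MoE path assumed
                    /*expert_size=*/0);
}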
Lines changed: 120 additions & 0 deletions
@@ -0,0 +1,120 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mlu_ops_api.h"
+#include "torch_mlu_ops.h"
+
+namespace xllm::mlu {
+
+void reshape_paged_cache(const torch::Tensor& key,
+                         const torch::Tensor& value,
+                         torch::Tensor& k_cache,
+                         torch::Tensor& v_cache,
+                         const torch::Tensor& slot_mapping,
+                         bool direction) {
+  tmo::torch_api::reshape_paged_cache(
+      key, value, k_cache, v_cache, slot_mapping, direction);
+}
+
+void flash_attention(const torch::Tensor& query,
+                     const torch::Tensor& key,
+                     const torch::Tensor& value,
+                     torch::Tensor& output,
+                     torch::Tensor& output_lse,
+                     int query_start_loc,
+                     int seq_start_loc,
+                     const std::optional<torch::Tensor>& alibi_slope,
+                     const std::optional<torch::Tensor>& attn_bias,
+                     const std::optional<torch::Tensor>& q_quant_scale,
+                     const std::optional<torch::Tensor>& k_quant_scale,
+                     const std::optional<torch::Tensor>& v_quant_scale,
+                     const std::optional<torch::Tensor>& out_quant_scale,
+                     const std::optional<torch::Tensor>& block_tables,
+                     int max_query_len,
+                     int max_seq_len,
+                     float scale,
+                     bool is_causal,
+                     int window_size_left,
+                     int window_size_right,
+                     const std::string& compute_dtype,
+                     bool return_lse) {
+  tmo::torch_api::flash_attention(query,
+                                  key,
+                                  value,
+                                  output,
+                                  output_lse,
+                                  query_start_loc,
+                                  seq_start_loc,
+                                  alibi_slope,
+                                  attn_bias,
+                                  q_quant_scale,
+                                  k_quant_scale,
+                                  v_quant_scale,
+                                  out_quant_scale,
+                                  block_tables,
+                                  max_query_len,
+                                  max_seq_len,
+                                  scale,
+                                  is_causal,
+                                  window_size_left,
+                                  window_size_right,
+                                  compute_dtype,
+                                  return_lse);
+}
+
+void single_query_cached_kv_attn(
+    const torch::Tensor& query,
+    const torch::Tensor& k_cache,
+    torch::Tensor& output,
+    const torch::Tensor& block_table,
+    const torch::Tensor& seq_lens,
+    const torch::Tensor& v_cache,
+    torch::Tensor& output_lse,
+    const std::optional<torch::Tensor>& q_quant_scale,
+    const std::optional<torch::Tensor>& k_cache_quant_scale,
+    const std::optional<torch::Tensor>& v_cache_quant_scale,
+    const std::optional<torch::Tensor>& out_quant_scale,
+    const std::optional<torch::Tensor>& alibi_slope,
+    const std::optional<torch::Tensor>& mask,
+    const std::string& compute_dtype,
+    int max_seq_len,
+    int window_size_left,
+    int window_size_right,
+    float scale,
+    bool return_lse,
+    int kv_cache_quant_bit_size) {
+  tmo::torch_api::single_query_cached_kv_attn(query,
+                                              k_cache,
+                                              output,
+                                              block_table,
+                                              seq_lens,
+                                              v_cache,
+                                              output_lse,
+                                              q_quant_scale,
+                                              k_cache_quant_scale,
+                                              v_cache_quant_scale,
+                                              out_quant_scale,
+                                              alibi_slope,
+                                              mask,
+                                              compute_dtype,
+                                              max_seq_len,
+                                              window_size_left,
+                                              window_size_right,
+                                              scale,
+                                              return_lse,
+                                              kv_cache_quant_bit_size);
+}
+
+}  // namespace xllm::mlu
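
All three wrappers forward their arguments unchanged to tmo::torch_api. A sketch of the smallest one, reshape_paged_cache, follows, assuming an MLU build; the paged [num_blocks, block_size, num_heads, head_dim] cache layout, the flat slot_mapping convention, and the reading of direction=true as "scatter new KV into the cache" are assumptions for illustration.

#include <torch/torch.h>
#include "mlu_ops_api.h"

void cache_write_example() {
  const int64_t tokens = 8, heads = 32, dim = 128;
  const int64_t blocks = 64, block_size = 16;
  auto opts = torch::dtype(torch::kHalf).device("mlu");  // assumes MLU backend
  auto key = torch::randn({tokens, heads, dim}, opts);
  auto value = torch::randn({tokens, heads, dim}, opts);
  auto k_cache = torch::zeros({blocks, block_size, heads, dim}, opts);
  auto v_cache = torch::zeros_like(k_cache);
  // slot_mapping[i]: flat cache slot (block * block_size + offset) for token i.
  auto slot_mapping =
      torch::arange(tokens, torch::dtype(torch::kInt).device("mlu"));
  xllm::mlu::reshape_paged_cache(key, value, k_cache, v_cache, slot_mapping,
                                 /*direction=*/true);  // assumed: KV -> cache
}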
Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mlu_ops_api.h"
+#include "torch_mlu_ops.h"
+
+namespace xllm::mlu {
+
+void fused_layernorm(const torch::Tensor& input,
+                     torch::Tensor& output,
+                     const std::optional<torch::Tensor>& residual,
+                     const torch::Tensor& weight,
+                     const std::optional<torch::Tensor>& beta,
+                     const std::optional<torch::Tensor>& bias,
+                     const std::optional<torch::Tensor>& quant_scale,
+                     const std::optional<torch::Tensor>& residual_out,
+                     const std::optional<torch::Tensor>& smooth_quant_scale,
+                     const std::optional<torch::Tensor>& normed_out,
+                     const std::string& mode,
+                     double eps,
+                     bool store_output_before_norm,
+                     bool store_output_after_norm,
+                     bool dynamic_quant) {
+  tmo::torch_api::fused_layernorm(input,
+                                  output,
+                                  residual,
+                                  weight,
+                                  beta,
+                                  bias,
+                                  quant_scale,
+                                  residual_out,
+                                  smooth_quant_scale,
+                                  normed_out,
+                                  mode,
+                                  eps,
+                                  store_output_before_norm,
+                                  store_output_after_norm,
+                                  dynamic_quant);
+}
+
+}  // namespace xllm::mlu
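
As with the other wrappers, this only forwards to tmo::torch_api::fused_layernorm. A sketch of the plain, non-quantized path follows, assuming an MLU build; the "rmsnorm" mode string and the eps value are illustrative assumptions.

#include <optional>
#include <torch/torch.h>
#include "mlu_ops_api.h"

void rmsnorm_example() {
  auto opts = torch::dtype(torch::kHalf).device("mlu");  // assumes MLU backend
  auto input = torch::randn({16, 4096}, opts);
  auto output = torch::empty_like(input);
  auto weight = torch::ones({4096}, opts);
  // All quantization- and residual-related arguments disabled for the
  // plain normalization path.
  xllm::mlu::fused_layernorm(input, output,
                             /*residual=*/std::nullopt,
                             weight,
                             /*beta=*/std::nullopt,
                             /*bias=*/std::nullopt,
                             /*quant_scale=*/std::nullopt,
                             /*residual_out=*/std::nullopt,
                             /*smooth_quant_scale=*/std::nullopt,
                             /*normed_out=*/std::nullopt,
                             /*mode=*/"rmsnorm",  // assumed mode string
                             /*eps=*/1e-6,
                             /*store_output_before_norm=*/false,
                             /*store_output_after_norm=*/false,
                             /*dynamic_quant=*/false);
}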

xllm/core/kernels/mlu/fused_moe.cpp

Lines changed: 2 additions & 1 deletion
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include "mlu_ops_api.h"
 #include "torch_mlu_ops.h"
-#include "torch_ops_api.h"
 
 namespace {
 torch::Tensor create_group_gemm_output(const torch::Tensor& a,
@@ -28,6 +28,7 @@ torch::Tensor create_group_gemm_output(const torch::Tensor& a,
 }  // namespace
 
 namespace xllm::mlu {
+
 torch::Tensor fused_moe(const torch::Tensor& hidden_states,
                         const torch::Tensor& gating_output,
                         const torch::Tensor& w1,
