Skip to content

Commit e2ebf59

Browse files
add fa3 backend unfinished
1 parent 17b29a1 commit e2ebf59

File tree

5 files changed

+583
-3
lines changed

5 files changed

+583
-3
lines changed

csrc/register.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,7 @@ PYBIND11_MODULE(vortex_torch_C, m){
88
m.def("Chunkwise_NH2HN_Transpose", &Chunkwise_NH2HN_Transpose);
99
m.def("Chunkwise_HN2NH_Transpose", &Chunkwise_HN2NH_Transpose);
1010
m.def("topk_output", &topk_output);
11+
m.def("sglang_plan_decode_fa3", &sglang_plan_decode_fa3);
12+
m.def("sglang_plan_prefill_fa3", &sglang_plan_prefill_fa3);
13+
m.def("Chunkwise_HN2NH_Transpose_FA3", &Chunkwise_HN2NH_Transpose_FA3);
1114
}

csrc/register.h

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,4 +86,48 @@ const int64_t max_seq_lengths
8686
);
8787

8888

89+
void sglang_plan_decode_fa3(
90+
const at::Tensor& cached_seq_lens,
91+
at::Tensor& dense_kv_indptr,
92+
at::Tensor& dense_kv_indices,
93+
at::Tensor& sparse_kv_indptr,
94+
at::Tensor& sparse_kv_indices,
95+
at::Tensor& dense_page_table,
96+
at::Tensor& dense_cache_seqlens,
97+
at::Tensor& sparse_page_table,
98+
at::Tensor& sparse_cache_seqlens,
99+
const at::Tensor& req_to_token,
100+
const at::Tensor& req_indices,
101+
at::Tensor& winfo_q_indices,
102+
at::Tensor& winfo_kv_offsets,
103+
at::Tensor& winfo_kv_lens,
104+
at::Tensor& winfo_num_workload,
105+
at::Tensor& winfo_chunk_size,
106+
const int64_t page_size,
107+
const int64_t num_kv_heads,
108+
const int64_t topk_val,
109+
const int64_t page_reserved_bos,
110+
const int64_t page_reserved_eos,
111+
const int64_t max_chunk_size,
112+
const int64_t min_chunk_size
113+
);
89114

115+
void sglang_plan_prefill_fa3(
116+
const at::Tensor& cached_seq_lens,
117+
const at::Tensor& cu_seqlens_q,
118+
const at::Tensor& req_to_token,
119+
const at::Tensor& req_indices,
120+
at::Tensor& page_table,
121+
at::Tensor& batch_table,
122+
const int64_t page_size,
123+
const int64_t num_kv_heads
124+
);
125+
126+
at::Tensor Chunkwise_HN2NH_Transpose_FA3(
127+
const at::Tensor& x,
128+
const at::Tensor& indptr,
129+
const at::Tensor& batch_table,
130+
const int64_t num_qo_heads,
131+
const int64_t num_kv_heads,
132+
const int64_t head_dim
133+
);

0 commit comments

Comments
 (0)