
Commit 2ccba55

feat - support mla cudagraph
1 parent 239fe45 commit 2ccba55

File tree

26 files changed, +394 -206 lines


rtp_llm/cpp/devices/cuda_impl/CudaFlashInfer.cc

Lines changed: 2 additions & 1 deletion
@@ -142,7 +142,8 @@ void FlashInferAttnParams::fillParams(torch::Tensor sequence_lengths,
                                       torch::Tensor input_lengths,
                                       torch::Tensor kv_cache_block_id_host,
                                       int batch_size,
-                                      int seq_size_per_block) {
+                                      int seq_size_per_block,
+                                      torch::Tensor prefix_lengths) {
     fillFlashInfer(nullptr,
                    torchTensor2Buffer(sequence_lengths),
                    torchTensor2Buffer(input_lengths),

rtp_llm/cpp/devices/cuda_impl/CudaFlashInfer.h

Lines changed: 2 additions & 1 deletion
@@ -109,7 +109,8 @@ struct FlashInferAttnParams: ParamsBase {
                     torch::Tensor input_lengths,
                     torch::Tensor kv_cache_block_id_host,
                     int batch_size,
-                    int seq_size_per_block) override;
+                    int seq_size_per_block,
+                    torch::Tensor prefix_lengths = torch::Tensor()) override;
     void fillFlashInfer(const BufferPtr& prefix_lengths_host,
                         const BufferPtr& sequence_lengths_host,
                         const BufferPtr& input_lengths_host,

rtp_llm/cpp/devices/cuda_impl/CudaGraphRunner.cc

Lines changed: 4 additions & 4 deletions
@@ -40,7 +40,7 @@ GraphBase* CudaDevice::getDeviceGraphRunner(const DeviceInitParams& params,
 
 py::object CudaGraphRunner::normalForward(PyModelInputs& inputs) {
     auto attn_pyobj = py_attn_pyobj_method_(inputs, false);
-    attn_pyobj.attr("prepare")(inputs);
+    attn_pyobj.attr("prepare")(inputs.attention_inputs);
     return py_forward_method_(inputs, attn_pyobj);
 }
 
@@ -100,7 +100,7 @@ void CudaGraphRunner::prepareInputs(PyModelInputs& inputs) {
             py_model_inputs_.attention_inputs.padding_offset,
             inputs.attention_inputs.padding_offset.size(0) * sizeof(int));
         auto attn_pyobj = graph_instances_[state_.current_real_graph_bs].mem_hold_.attn_pyobj_;
-        attn_pyobj.attr("prepare_replay")(inputs);
+        attn_pyobj.attr("prepare_replay")(inputs.attention_inputs);
     } else {
         auto& py_model_inputs_ = graph_instances_[state_.current_real_graph_seq_len].mem_hold_.py_model_inputs_;
 
@@ -342,7 +342,7 @@ void CudaGraphRunner::initCapture() {
     initKernelInternalMemory();
     // get real output data type
     auto attn_pyobj = py_attn_pyobj_method_(capture_mem_hold_.py_model_inputs_, true);
-    attn_pyobj.attr("prepare")(capture_mem_hold_.py_model_inputs_);
+    attn_pyobj.attr("prepare")(capture_mem_hold_.py_model_inputs_.attention_inputs);
     RTP_LLM_LOG_INFO("initCapture forward for output datatype start");
     auto py_outputs_obj = py_forward_method_(capture_mem_hold_.py_model_inputs_, attn_pyobj);
     RTP_LLM_LOG_INFO("initCapture forward for output datatype end");
@@ -383,7 +383,7 @@ void CudaGraphRunner::captureOneGraphInstance(int key, const char* key_type) {
     // WarmUp twice
     RTP_LLM_LOG_INFO("WarmUp for %s %d start.", key_type, key);
     auto attn_pyobj = graph_instances_[key].mem_hold_.attn_pyobj_;
-    attn_pyobj.attr("prepare")(inputs);
+    attn_pyobj.attr("prepare")(inputs.attention_inputs);
     py_forward_method_(inputs, attn_pyobj);
     py_forward_method_(inputs, attn_pyobj);
     RTP_LLM_LOG_INFO("WarmUp for %s %d successfully.", key_type, key);
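Note: the call sites above now hand the Python attention object only inputs.attention_inputs instead of the full PyModelInputs. On the Python side this implies a prepare/prepare_replay pair that consumes PyAttentionInputs directly. The sketch below is only an illustration of that calling contract; the class name and its internals are assumptions, not code from this repository.

    class AttnParamsHolder:
        """Hypothetical sketch mirroring the prepare()/prepare_replay() calls made from C++ above."""

        def __init__(self):
            self.attn_inputs = None

        def prepare(self, attn_inputs):
            # Receives PyAttentionInputs (not the full PyModelInputs) before a normal forward or graph capture.
            self.attn_inputs = attn_inputs

        def prepare_replay(self, attn_inputs):
            # Receives PyAttentionInputs again before each CUDA graph replay to refresh host-side state.
            self.attn_inputs = attn_inputs
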

rtp_llm/cpp/models/PyWrappedModel.cc

Lines changed: 1 addition & 1 deletion
@@ -257,7 +257,7 @@ GptModelOutputs PyWrappedModel::forward(const GptModelInputs& inputs) {
     } else {
         DevicePerfWrapper wrapper(device_, "normal forward");
         auto attn_pyobj = py_model_.attr("prepare_fmha_impl")(py_model_inputs, false);
-        attn_pyobj.attr("prepare")(py_model_inputs);
+        attn_pyobj.attr("prepare")(py_model_inputs.attention_inputs);
         auto py_model_forward = py_model_.attr("forward");
         auto outputs = py_model_forward(py_model_inputs, attn_pyobj);
         py_model_outputs = outputs.cast<PyModelOutputs>();

rtp_llm/models/deepseek_v2.py

Lines changed: 3 additions & 0 deletions
@@ -523,6 +523,9 @@ def _create_config(cls, ckpt_path: str):
         DeepSeekV2._from_hf(config, ckpt_path)
         return config
 
+    def support_cuda_graph(self) -> bool:
+        return True
+
     def _create_python_model(self) -> Optional[GptModelBase]:
         self.py_model = GenericMoeModel(self.config, self.weight)
 
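DeepSeekV2 now opts into CUDA graph execution via support_cuda_graph(). A runner can use this flag to decide whether to capture graphs at all; the helper below is a minimal sketch of that gating, with the function and flag names invented for illustration only.

    def should_capture_cuda_graph(model, enable_cuda_graph: bool) -> bool:
        # Capture CUDA graphs only when the model opts in and the feature is switched on.
        return enable_cuda_graph and model.support_cuda_graph()
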

rtp_llm/models_py/bindings/OpDefs.cc

Lines changed: 0 additions & 13 deletions
@@ -46,19 +46,6 @@ void registerPyOpDefs(pybind11::module& m) {
         pybind11::arg("seq_size_per_block"),
         "Fill parameters for CUDA graph execution");
 
-    pybind11::class_<MlaParams, std::shared_ptr<MlaParams>, rtp_llm::ParamsBase>(m, "MlaParams")
-        .def(pybind11::init<>())
-        .def_readonly("batch_indice", &MlaParams::batch_indice)
-        .def_readonly("positions", &MlaParams::positions)
-        .def_readonly("paged_kv_last_page_len", &MlaParams::paged_kv_last_page_len)
-        .def_readonly("kvlen", &MlaParams::kvlen)
-        .def_readonly("page_indice", &MlaParams::page_indice)
-        .def_readonly("reuse_cache_page_indice", &MlaParams::reuse_cache_page_indice)
-        .def_readonly("decode_page_indptr", &MlaParams::decode_page_indptr)
-        .def_readonly("prefill_page_indptr", &MlaParams::prefill_page_indptr)
-        .def_readonly("qo_indptr", &MlaParams::qo_indptr)
-        .def_readonly("batch_reuse_info_vec", &MlaParams::batch_reuse_info_vec);
-
     pybind11::class_<PyPrefillCudaGaphCopyParams>(m, "PyPrefillCudaGaphCopyParams")
         .def(pybind11::init<>())
         .def_readonly("cuda_graph_prefill_batch_size", &PyPrefillCudaGaphCopyParams::cuda_graph_prefill_batch_size)

rtp_llm/models_py/bindings/OpDefs.h

Lines changed: 1 addition & 17 deletions
@@ -8,22 +8,6 @@
 #include "rtp_llm/models_py/bindings/ParamsBase.h"
 #include "rtp_llm/cpp/utils/Logger.h"
 namespace torch_ext {
-struct MlaParams: public rtp_llm::ParamsBase {
-    torch::Tensor batch_indice;
-    torch::Tensor positions;
-    torch::Tensor paged_kv_last_page_len;
-    torch::Tensor kvlen;
-    torch::Tensor page_indice;
-    torch::Tensor reuse_cache_page_indice;
-    torch::Tensor decode_page_indptr;
-    torch::Tensor prefill_page_indptr;
-    torch::Tensor qo_indptr;
-    torch::Tensor batch_reuse_info_vec;
-
-    // Hidden field to keep FlashInferMlaAttnParams object alive
-    // This ensures the underlying buffers (buf_d, buf_h) are not deallocated
-    std::shared_ptr<void> _params_holder;
-};
 
 struct KVCache {
     torch::Tensor k_cache_base;
@@ -96,7 +80,7 @@ struct PyAttentionInputs {
     std::optional<PyCacheStoreInputs> cache_store_inputs;
 
     std::optional<PyPrefillCudaGaphCopyParams> prefill_cuda_graph_copy_params;
-    bool is_s_padded = false;
+    bool is_s_padded = false;
 };
 
 struct BertEmbeddingInputs {

rtp_llm/models_py/bindings/ParamsBase.h

Lines changed: 2 additions & 1 deletion
@@ -12,7 +12,8 @@ class ParamsBase {
                     torch::Tensor input_lengths,
                     torch::Tensor kv_cache_block_id_host,
                     int batch_size,
-                    int seq_size_per_block) {};
+                    int seq_size_per_block,
+                    torch::Tensor prefix_lengths = torch::Tensor()) {};
     // check whether the parmas can be recycled automatically.
     virtual bool check_recycle() {
         return false;
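With prefix_lengths added as a defaulted argument on the ParamsBase virtual, the dense FlashInfer params and the MLA params now share one fill signature, so the CUDA graph runner can refill either object the same way. Assuming the method is exposed to Python as fill_params with this argument list (the corresponding binding update is not part of the excerpt shown here), a caller-side sketch could look like the following; the helper name and the empty-tensor default are assumptions.

    import torch

    def refill_attn_params(params, sequence_lengths, input_lengths,
                           kv_cache_block_id_host, batch_size, seq_size_per_block,
                           prefix_lengths=None):
        # `params` may be a FlashInferAttnParams or a FlashInferMlaAttnParams;
        # both now override the same virtual, so either can be refilled uniformly.
        if prefix_lengths is None:
            prefix_lengths = torch.empty(0, dtype=torch.int32)  # mirrors the C++ default torch::Tensor()
        params.fill_params(sequence_lengths, input_lengths, kv_cache_block_id_host,
                           batch_size, seq_size_per_block, prefix_lengths)
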

rtp_llm/models_py/bindings/cuda/FlashInferMlaParams.cc

Lines changed: 68 additions & 41 deletions
@@ -11,6 +11,9 @@ using namespace torch_ext;
 
 namespace rtp_llm {
 
+static const int MIN_CACHE_PAGE_NUM = 1024 * 1024;
+// static const int MIN_CACHE_BATCH_SIZE = 256;
+// static const int MIN_CACHE_INPUT_TOKEN_NUM = 512;
 std::tuple<torch::Tensor, std::vector<torch::Tensor>>
 FlashInferMlaAttnParams::allocateManyBuffer(const std::vector<std::vector<int64_t>>& shapes, bool is_device) {
     std::vector<torch::Tensor> tensors;
@@ -65,7 +68,7 @@ void FlashInferMlaAttnParams::ensureTensorSize(
     // Update max sizes
     max_batch_size_ = std::max(max_batch_size_, batch_size);
     max_input_token_num_ = std::max(max_input_token_num_, input_token_num);
-    max_page_num_ = std::max(max_page_num_, page_num);
+    max_page_num_ = std::max(max_page_num_, MIN_CACHE_PAGE_NUM);
     max_reuse_page_num_ = std::max(max_reuse_page_num_, reuse_page_num);
     max_batch_reuse_info_ = std::max(max_batch_reuse_info_, batch_reuse_info_size);
 
@@ -317,11 +320,12 @@ void FlashInferMlaAttnParams::refreshBuffer(
     batch_reuse_info_vec_h.unsafeGetTensorImpl()->set_sizes_contiguous(shape);
 }
 
-MlaParams FlashInferMlaAttnParams::fillParams(torch::Tensor t_prefix_lengths,
-                                              torch::Tensor t_sequence_lengths,
-                                              torch::Tensor t_input_lengths,
-                                              torch::Tensor t_kv_cache_block_id_host,
-                                              int seq_size_per_block) {
+void FlashInferMlaAttnParams::fillParams(torch::Tensor t_sequence_lengths,
+                                         torch::Tensor t_input_lengths,
+                                         torch::Tensor t_kv_cache_block_id_host,
+                                         int t_batch_size,
+                                         int seq_size_per_block,
+                                         torch::Tensor t_prefix_lengths) {
     const int batch_size = t_input_lengths.size(0);
 
     // First pass: calculate required sizes accurately
@@ -370,54 +374,77 @@ MlaParams FlashInferMlaAttnParams::fillParams(torch::Tensor t_prefix_lengths,
     // Refresh buffer (copy to DEVICE and update shapes)
     refreshBuffer(batch_size, input_token_num, page_num, reuse_page_num, batch_reuse_info_size);
 
-    batch_indice = batch_indice_d;
-    page_indice = page_indice_d;
-    reuse_cache_page_indice = reuse_page_num > 0 ? reuse_cache_page_indice_d : torch::Tensor();
-    decode_page_indptr = decode_page_indptr_d;
-    prefill_page_indptr = prefill_page_indptr_d;
-    paged_kv_last_page_len = paged_kv_last_page_len_d;
-    qo_indptr = qo_indptr_d;
-    kvlen = kvlen_d;
-    positions = positions_d;
-    batch_reuse_info_vec = batch_size > 0 ? batch_reuse_info_vec_d : torch::Tensor();
-
-    // Return MlaParams with DEVICE tensors
-    MlaParams params;
-    params.batch_indice = batch_indice_d;
-    params.page_indice = page_indice_d;
-    params.reuse_cache_page_indice = reuse_page_num > 0 ? reuse_cache_page_indice_d : torch::Tensor();
-    params.decode_page_indptr = decode_page_indptr_d;
-    params.prefill_page_indptr = prefill_page_indptr_d;
-    params.paged_kv_last_page_len = paged_kv_last_page_len_d;
-    params.qo_indptr = qo_indptr_d;
-    params.kvlen = kvlen_d;
-    params.positions = positions_d;
-    params.batch_reuse_info_vec = batch_size > 0 ? batch_reuse_info_vec_d : torch::Tensor();
-
-    return params;
+    return;
 }
 
 void registerPyFlashInferMlaParams(pybind11::module& m) {
+    pybind11::class_<FlashInferMlaAttnParams, std::shared_ptr<FlashInferMlaAttnParams>, rtp_llm::ParamsBase>(
+        m, "FlashInferMlaAttnParams")
+        .def(pybind11::init<>())
+        // HOST tensors (_h suffix)
+        .def_readonly("batch_indice_h", &FlashInferMlaAttnParams::batch_indice_h, "Batch indices on HOST")
+        .def_readonly("page_indice_h", &FlashInferMlaAttnParams::page_indice_h, "Page indices on HOST")
+        .def_readonly("reuse_cache_page_indice_h",
+                      &FlashInferMlaAttnParams::reuse_cache_page_indice_h,
+                      "Reuse cache page indices on HOST")
+        .def_readonly(
+            "decode_page_indptr_h", &FlashInferMlaAttnParams::decode_page_indptr_h, "Decode page indptr on HOST")
+        .def_readonly(
+            "prefill_page_indptr_h", &FlashInferMlaAttnParams::prefill_page_indptr_h, "Prefill page indptr on HOST")
+        .def_readonly("paged_kv_last_page_len_h",
+                      &FlashInferMlaAttnParams::paged_kv_last_page_len_h,
+                      "Paged KV last page length on HOST")
+        .def_readonly("qo_indptr_h", &FlashInferMlaAttnParams::qo_indptr_h, "Query/output indptr on HOST")
+        .def_readonly("kvlen_h", &FlashInferMlaAttnParams::kvlen_h, "KV length on HOST")
+        .def_readonly("positions_h", &FlashInferMlaAttnParams::positions_h, "Positions on HOST")
+        .def_readonly("batch_reuse_info_vec_h",
+                      &FlashInferMlaAttnParams::batch_reuse_info_vec_h,
+                      "Batch reuse info vector on HOST")
+        // DEVICE tensors (_d suffix)
+        .def_readonly("batch_indice_d", &FlashInferMlaAttnParams::batch_indice_d, "Batch indices on DEVICE")
+        .def_readonly("page_indice_d", &FlashInferMlaAttnParams::page_indice_d, "Page indices on DEVICE")
+        .def_readonly("reuse_cache_page_indice_d",
+                      &FlashInferMlaAttnParams::reuse_cache_page_indice_d,
+                      "Reuse cache page indices on DEVICE")
+        .def_readonly(
+            "decode_page_indptr_d", &FlashInferMlaAttnParams::decode_page_indptr_d, "Decode page indptr on DEVICE")
+        .def_readonly(
+            "prefill_page_indptr_d", &FlashInferMlaAttnParams::prefill_page_indptr_d, "Prefill page indptr on DEVICE")
+        .def_readonly("paged_kv_last_page_len_d",
+                      &FlashInferMlaAttnParams::paged_kv_last_page_len_d,
+                      "Paged KV last page length on DEVICE")
+        .def_readonly("qo_indptr_d", &FlashInferMlaAttnParams::qo_indptr_d, "Query/output indptr on DEVICE")
+        .def_readonly("kvlen_d", &FlashInferMlaAttnParams::kvlen_d, "KV length on DEVICE")
+        .def_readonly("positions_d", &FlashInferMlaAttnParams::positions_d, "Positions on DEVICE")
+        .def_readonly("batch_reuse_info_vec_d",
+                      &FlashInferMlaAttnParams::batch_reuse_info_vec_d,
+                      "Batch reuse info vector on DEVICE");
+
     m.def(
         "fill_mla_params",
-        [](torch::Tensor t_prefill_lengths,
-           torch::Tensor t_sequence_lengths,
+        [](torch::Tensor t_sequence_lengths,
           torch::Tensor t_input_lengths,
           torch::Tensor t_kv_cache_block_id_host,
-           int seq_size_per_block) {
-            auto params = std::make_shared<rtp_llm::FlashInferMlaAttnParams>();
-            auto mla_params = params->fillParams(
-                t_prefill_lengths, t_sequence_lengths, t_input_lengths, t_kv_cache_block_id_host, seq_size_per_block);
+           int batch_size,
+           int seq_size_per_block,
+           torch::Tensor t_prefix_lengths) {
+            auto params = std::make_shared<rtp_llm::FlashInferMlaAttnParams>();
+            params->fillParams(t_sequence_lengths,
+                               t_input_lengths,
                               t_kv_cache_block_id_host,
+                               batch_size,
+                               seq_size_per_block,
+                               t_prefix_lengths);
            // Store the params object in _params_holder to keep it alive
            // This ensures the underlying buffers (buf_d, buf_h) are not deallocated
-            mla_params._params_holder = std::static_pointer_cast<void>(params);
-            return mla_params;
+            return params;
         },
-        pybind11::arg("t_prefill_lengths"),
         pybind11::arg("t_sequence_lengths"),
         pybind11::arg("t_input_lengths"),
         pybind11::arg("t_kv_cache_block_id_host"),
-        pybind11::arg("seq_size_per_block"));
+        pybind11::arg("batch_size"),
+        pybind11::arg("seq_size_per_block"),
+        pybind11::arg("t_prefix_lengths"));
 }
 
 } // namespace rtp_llm
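fill_mla_params now returns the FlashInferMlaAttnParams object itself, which keeps buf_h/buf_d alive and exposes the _h/_d tensor views directly. A minimal call-side sketch follows; the module name rtp_llm_ops is a placeholder for the compiled pybind11 extension, and the example tensor values and dtypes are illustrative assumptions only.

    import torch
    import rtp_llm_ops  # placeholder name for the compiled bindings module

    # Toy decode batch of 2 sequences (values are illustrative only).
    sequence_lengths = torch.tensor([17, 33], dtype=torch.int32)
    input_lengths = torch.tensor([1, 1], dtype=torch.int32)
    kv_cache_block_id_host = torch.zeros((2, 4), dtype=torch.int32)
    prefix_lengths = torch.tensor([0, 0], dtype=torch.int32)

    params = rtp_llm_ops.fill_mla_params(
        t_sequence_lengths=sequence_lengths,
        t_input_lengths=input_lengths,
        t_kv_cache_block_id_host=kv_cache_block_id_host,
        batch_size=2,
        seq_size_per_block=64,
        t_prefix_lengths=prefix_lengths,
    )

    # DEVICE-side views are read straight off the returned object;
    # holding on to `params` keeps the underlying pinned/device buffers alive.
    qo_indptr = params.qo_indptr_d
    page_indptr = params.decode_page_indptr_d
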

rtp_llm/models_py/bindings/cuda/FlashInferMlaParams.h

Lines changed: 29 additions & 39 deletions
@@ -17,40 +17,6 @@ class FlashInferMlaAttnParams: public ParamsBase {
     torch::Tensor buf_h; // Large continuous HOST buffer (pinned memory)
     torch::Tensor buf_d; // Large continuous DEVICE buffer
 
-    // Tensor views into buf_h and buf_d
-    torch::Tensor batch_indice_h;
-    torch::Tensor page_indice_h;
-    torch::Tensor reuse_cache_page_indice_h;
-    torch::Tensor decode_page_indptr_h;
-    torch::Tensor prefill_page_indptr_h;
-    torch::Tensor paged_kv_last_page_len_h;
-    torch::Tensor qo_indptr_h;
-    torch::Tensor kvlen_h;
-    torch::Tensor positions_h;
-    torch::Tensor batch_reuse_info_vec_h;
-
-    torch::Tensor batch_indice_d;
-    torch::Tensor page_indice_d;
-    torch::Tensor reuse_cache_page_indice_d;
-    torch::Tensor decode_page_indptr_d;
-    torch::Tensor prefill_page_indptr_d;
-    torch::Tensor paged_kv_last_page_len_d;
-    torch::Tensor qo_indptr_d;
-    torch::Tensor kvlen_d;
-    torch::Tensor positions_d;
-    torch::Tensor batch_reuse_info_vec_d;
-
-    torch::Tensor batch_indice;
-    torch::Tensor positions;
-    torch::Tensor paged_kv_last_page_len;
-    torch::Tensor kvlen;
-    torch::Tensor page_indice;
-    torch::Tensor reuse_cache_page_indice;
-    torch::Tensor decode_page_indptr;
-    torch::Tensor prefill_page_indptr;
-    torch::Tensor qo_indptr;
-    torch::Tensor batch_reuse_info_vec;
-
     // Reserved sizes
     int max_batch_size_ = 0;
     int max_input_token_num_ = 0;
@@ -83,11 +49,35 @@ class FlashInferMlaAttnParams: public ParamsBase {
     ensureTensorSize(int batch_size, int input_token_num, int page_num, int reuse_page_num, int batch_reuse_info_size);
 
 public:
-    MlaParams fillParams(torch::Tensor t_prefix_lengths,
-                         torch::Tensor t_sequence_lengths,
-                         torch::Tensor t_input_lengths,
-                         torch::Tensor t_kv_cache_block_id_host,
-                         int seq_size_per_block);
+    // Tensor views into buf_h and buf_d
+    torch::Tensor batch_indice_h;
+    torch::Tensor page_indice_h;
+    torch::Tensor reuse_cache_page_indice_h;
+    torch::Tensor decode_page_indptr_h;
+    torch::Tensor prefill_page_indptr_h;
+    torch::Tensor paged_kv_last_page_len_h;
+    torch::Tensor qo_indptr_h;
+    torch::Tensor kvlen_h;
+    torch::Tensor positions_h;
+    torch::Tensor batch_reuse_info_vec_h;
+
+    torch::Tensor batch_indice_d;
+    torch::Tensor page_indice_d;
+    torch::Tensor reuse_cache_page_indice_d;
+    torch::Tensor decode_page_indptr_d;
+    torch::Tensor prefill_page_indptr_d;
+    torch::Tensor paged_kv_last_page_len_d;
+    torch::Tensor qo_indptr_d;
+    torch::Tensor kvlen_d;
+    torch::Tensor positions_d;
+    torch::Tensor batch_reuse_info_vec_d;
+
+    void fillParams(torch::Tensor sequence_lengths,
+                    torch::Tensor input_lengths,
+                    torch::Tensor kv_cache_block_id_host,
+                    int batch_size,
+                    int seq_size_per_block,
+                    torch::Tensor prefix_lengths = torch::Tensor()) override;
 };
 void registerPyFlashInferMlaParams(pybind11::module& m);
 
