@@ -11,6 +11,9 @@ using namespace torch_ext;
 
 namespace rtp_llm {
 
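+// Minimum number of page entries to reserve in the cached index buffers (used in ensureTensorSize below).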
+static const int MIN_CACHE_PAGE_NUM = 1024 * 1024;
+// static const int MIN_CACHE_BATCH_SIZE = 256;
+// static const int MIN_CACHE_INPUT_TOKEN_NUM = 512;
 std::tuple<torch::Tensor, std::vector<torch::Tensor>>
 FlashInferMlaAttnParams::allocateManyBuffer(const std::vector<std::vector<int64_t>>& shapes, bool is_device) {
     std::vector<torch::Tensor> tensors;
@@ -65,7 +68,7 @@ void FlashInferMlaAttnParams::ensureTensorSize(
     // Update max sizes
     max_batch_size_ = std::max(max_batch_size_, batch_size);
     max_input_token_num_ = std::max(max_input_token_num_, input_token_num);
-    max_page_num_ = std::max(max_page_num_, page_num);
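+    // Size the page buffers to at least MIN_CACHE_PAGE_NUM entries so subsequent requests rarely trigger a re-allocation.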
+    max_page_num_ = std::max(max_page_num_, std::max(page_num, MIN_CACHE_PAGE_NUM));
     max_reuse_page_num_ = std::max(max_reuse_page_num_, reuse_page_num);
     max_batch_reuse_info_ = std::max(max_batch_reuse_info_, batch_reuse_info_size);
 
@@ -317,11 +320,12 @@ void FlashInferMlaAttnParams::refreshBuffer(
     batch_reuse_info_vec_h.unsafeGetTensorImpl()->set_sizes_contiguous(shape);
 }
 
-MlaParams FlashInferMlaAttnParams::fillParams(torch::Tensor t_prefix_lengths,
-                                              torch::Tensor t_sequence_lengths,
-                                              torch::Tensor t_input_lengths,
-                                              torch::Tensor t_kv_cache_block_id_host,
-                                              int           seq_size_per_block) {
+void FlashInferMlaAttnParams::fillParams(torch::Tensor t_sequence_lengths,
+                                         torch::Tensor t_input_lengths,
+                                         torch::Tensor t_kv_cache_block_id_host,
+                                         int           t_batch_size,
+                                         int           seq_size_per_block,
+                                         torch::Tensor t_prefix_lengths) {
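+    // fillParams now populates the member tensors in place (exposed to Python via the bindings below) rather than returning an MlaParams struct.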
     const int batch_size = t_input_lengths.size(0);
 
     // First pass: calculate required sizes accurately
@@ -370,54 +374,77 @@ MlaParams FlashInferMlaAttnParams::fillParams(torch::Tensor t_prefix_lengths,
     // Refresh buffer (copy to DEVICE and update shapes)
     refreshBuffer(batch_size, input_token_num, page_num, reuse_page_num, batch_reuse_info_size);
 
-    batch_indice = batch_indice_d;
-    page_indice = page_indice_d;
-    reuse_cache_page_indice = reuse_page_num > 0 ? reuse_cache_page_indice_d : torch::Tensor();
-    decode_page_indptr = decode_page_indptr_d;
-    prefill_page_indptr = prefill_page_indptr_d;
-    paged_kv_last_page_len = paged_kv_last_page_len_d;
-    qo_indptr = qo_indptr_d;
-    kvlen = kvlen_d;
-    positions = positions_d;
-    batch_reuse_info_vec = batch_size > 0 ? batch_reuse_info_vec_d : torch::Tensor();
-
-    // Return MlaParams with DEVICE tensors
-    MlaParams params;
-    params.batch_indice = batch_indice_d;
-    params.page_indice = page_indice_d;
-    params.reuse_cache_page_indice = reuse_page_num > 0 ? reuse_cache_page_indice_d : torch::Tensor();
-    params.decode_page_indptr = decode_page_indptr_d;
-    params.prefill_page_indptr = prefill_page_indptr_d;
-    params.paged_kv_last_page_len = paged_kv_last_page_len_d;
-    params.qo_indptr = qo_indptr_d;
-    params.kvlen = kvlen_d;
-    params.positions = positions_d;
-    params.batch_reuse_info_vec = batch_size > 0 ? batch_reuse_info_vec_d : torch::Tensor();
-
-    return params;
+    return;
 }
 
 void registerPyFlashInferMlaParams(pybind11::module& m) {
+    pybind11::class_<FlashInferMlaAttnParams, std::shared_ptr<FlashInferMlaAttnParams>, rtp_llm::ParamsBase>(
+        m, "FlashInferMlaAttnParams")
+        .def(pybind11::init<>())
+        // HOST tensors (_h suffix)
+        .def_readonly("batch_indice_h", &FlashInferMlaAttnParams::batch_indice_h, "Batch indices on HOST")
+        .def_readonly("page_indice_h", &FlashInferMlaAttnParams::page_indice_h, "Page indices on HOST")
+        .def_readonly("reuse_cache_page_indice_h",
+                      &FlashInferMlaAttnParams::reuse_cache_page_indice_h,
+                      "Reuse cache page indices on HOST")
+        .def_readonly(
+            "decode_page_indptr_h", &FlashInferMlaAttnParams::decode_page_indptr_h, "Decode page indptr on HOST")
+        .def_readonly(
+            "prefill_page_indptr_h", &FlashInferMlaAttnParams::prefill_page_indptr_h, "Prefill page indptr on HOST")
+        .def_readonly("paged_kv_last_page_len_h",
+                      &FlashInferMlaAttnParams::paged_kv_last_page_len_h,
+                      "Paged KV last page length on HOST")
+        .def_readonly("qo_indptr_h", &FlashInferMlaAttnParams::qo_indptr_h, "Query/output indptr on HOST")
+        .def_readonly("kvlen_h", &FlashInferMlaAttnParams::kvlen_h, "KV length on HOST")
+        .def_readonly("positions_h", &FlashInferMlaAttnParams::positions_h, "Positions on HOST")
+        .def_readonly("batch_reuse_info_vec_h",
+                      &FlashInferMlaAttnParams::batch_reuse_info_vec_h,
+                      "Batch reuse info vector on HOST")
+        // DEVICE tensors (_d suffix)
+        .def_readonly("batch_indice_d", &FlashInferMlaAttnParams::batch_indice_d, "Batch indices on DEVICE")
+        .def_readonly("page_indice_d", &FlashInferMlaAttnParams::page_indice_d, "Page indices on DEVICE")
+        .def_readonly("reuse_cache_page_indice_d",
+                      &FlashInferMlaAttnParams::reuse_cache_page_indice_d,
+                      "Reuse cache page indices on DEVICE")
+        .def_readonly(
+            "decode_page_indptr_d", &FlashInferMlaAttnParams::decode_page_indptr_d, "Decode page indptr on DEVICE")
+        .def_readonly(
+            "prefill_page_indptr_d", &FlashInferMlaAttnParams::prefill_page_indptr_d, "Prefill page indptr on DEVICE")
+        .def_readonly("paged_kv_last_page_len_d",
+                      &FlashInferMlaAttnParams::paged_kv_last_page_len_d,
+                      "Paged KV last page length on DEVICE")
+        .def_readonly("qo_indptr_d", &FlashInferMlaAttnParams::qo_indptr_d, "Query/output indptr on DEVICE")
+        .def_readonly("kvlen_d", &FlashInferMlaAttnParams::kvlen_d, "KV length on DEVICE")
+        .def_readonly("positions_d", &FlashInferMlaAttnParams::positions_d, "Positions on DEVICE")
+        .def_readonly("batch_reuse_info_vec_d",
+                      &FlashInferMlaAttnParams::batch_reuse_info_vec_d,
+                      "Batch reuse info vector on DEVICE");
+
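+    // The _h/_d members are read-only attributes on the Python side; they reference the tensors owned by this params object.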
     m.def(
         "fill_mla_params",
-        [](torch::Tensor t_prefill_lengths,
-           torch::Tensor t_sequence_lengths,
+        [](torch::Tensor t_sequence_lengths,
           torch::Tensor t_input_lengths,
           torch::Tensor t_kv_cache_block_id_host,
-           int           seq_size_per_block) {
-            auto params = std::make_shared<rtp_llm::FlashInferMlaAttnParams>();
-            auto mla_params = params->fillParams(
-                t_prefill_lengths, t_sequence_lengths, t_input_lengths, t_kv_cache_block_id_host, seq_size_per_block);
+           int           batch_size,
+           int           seq_size_per_block,
+           torch::Tensor t_prefix_lengths) {
+            auto params = std::make_shared<rtp_llm::FlashInferMlaAttnParams>();
+            params->fillParams(t_sequence_lengths,
+                               t_input_lengths,
+                               t_kv_cache_block_id_host,
+                               batch_size,
+                               seq_size_per_block,
+                               t_prefix_lengths);
-            // Store the params object in _params_holder to keep it alive
-            // This ensures the underlying buffers (buf_d, buf_h) are not deallocated
+            // Returning the shared_ptr keeps the params object (and its underlying buf_d/buf_h buffers)
+            // alive for as long as the Python side holds a reference.
-            mla_params._params_holder = std::static_pointer_cast<void>(params);
-            return mla_params;
+            return params;
         },
-        pybind11::arg("t_prefill_lengths"),
         pybind11::arg("t_sequence_lengths"),
         pybind11::arg("t_input_lengths"),
         pybind11::arg("t_kv_cache_block_id_host"),
-        pybind11::arg("seq_size_per_block"));
+        pybind11::arg("batch_size"),
+        pybind11::arg("seq_size_per_block"),
+        pybind11::arg("t_prefix_lengths"));
 }
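+// Example Python-side usage (a sketch; the module name `libth_transformer` and the tensor
+// preparation are illustrative, not taken from this file):
+//
+//   params = libth_transformer.fill_mla_params(
+//       t_sequence_lengths=sequence_lengths,
+//       t_input_lengths=input_lengths,
+//       t_kv_cache_block_id_host=kv_cache_block_id_host,
+//       batch_size=batch_size,
+//       seq_size_per_block=seq_size_per_block,
+//       t_prefix_lengths=prefix_lengths)
+//   qo_indptr = params.qo_indptr_d  # DEVICE tensors exposed via def_readonly above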
 
 } // namespace rtp_llm