@@ -50,23 +50,57 @@ class _FlashInferPlanner:
"""A class interface to handle flashinfer-related planning/wrapping operations."""

workspace_buffer: Optional[torch.Tensor]
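    # persistent device buffers that back CUDA-graph-compatible decode wrappers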
paged_kv_indptr_buffer: Optional[torch.Tensor]
paged_kv_indices_buffer: Optional[torch.Tensor]
paged_kv_last_page_len_buffer: Optional[torch.Tensor]
prefill_wrapper: Optional[flashinfer.BatchPrefillWithPagedKVCacheWrapper]
decode_wrapper: Optional[flashinfer.BatchDecodeWithPagedKVCacheWrapper]
cached_decode_wrappers: Dict[PlanParams, flashinfer.BatchDecodeWithPagedKVCacheWrapper]
cached_cuda_graph_decode_wrappers: Dict[
PlanParams, flashinfer.BatchDecodeWithPagedKVCacheWrapper
]
plan_params: Optional[PlanParams]

def __init__(self):
self.workspace_buffer = None
self.paged_kv_indptr_buffer = None
self.paged_kv_indices_buffer = None
self.paged_kv_last_page_len_buffer = None
self.prefill_wrapper = None
self.decode_wrapper = None
self.cached_decode_wrappers = {}
self.cached_cuda_graph_decode_wrappers = {}
self.plan_params = None

def _init_decode_wrapper(
self,
use_cuda_graph: bool = False,
num_pages: Optional[int] = None,
batch_size: Optional[int] = None,
):
assert self.workspace_buffer is not None
if use_cuda_graph:
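            # CUDA graph replay reads the plan metadata from fixed device addresses,
            # so the wrapper must be built on top of preallocated, persistent buffers.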
assert self.paged_kv_indptr_buffer is not None
assert self.paged_kv_indices_buffer is not None
assert self.paged_kv_last_page_len_buffer is not None
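            # Growing the indices buffer here is safe: this path only runs during
            # warm-up, i.e., before any graph has captured the buffer's address.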
if len(self.paged_kv_indices_buffer) < num_pages:
ad_logger.info(
f"Resizing paged_kv_indices_buffer from {len(self.paged_kv_indices_buffer)} to {num_pages}"
)
self.paged_kv_indices_buffer.resize_(num_pages)
return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
self.workspace_buffer,
"NHD",
use_cuda_graph=True,
paged_kv_indptr_buffer=self.paged_kv_indptr_buffer[: batch_size + 1],
paged_kv_indices_buffer=self.paged_kv_indices_buffer[:num_pages],
paged_kv_last_page_len_buffer=self.paged_kv_last_page_len_buffer[:batch_size],
use_tensor_cores=True,
)
else:
return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
self.workspace_buffer,
"NHD",
use_tensor_cores=True,
)

def init_workspace(self, workspace_buffer: torch.Tensor):
self.__init__() # reset all state
@@ -96,7 +130,9 @@ def plan(
flashinfer.BatchDecodeWithPagedKVCacheWrapper,
]:
# plan decode helper function
def _plan_decode(
wrapper: flashinfer.BatchDecodeWithPagedKVCacheWrapper,
):
wrapper.plan(
kv_page_indptr,
kv_page_indices,
@@ -111,18 +147,26 @@ def _plan_decode(wrapper: flashinfer.BatchDecodeWithPagedKVCacheWrapper):
)

        # plan during CUDA graph warm-up so the plan is already cached before capture begins
if (
cuda_graph_state.in_warm_up()
and plan_params not in self.cached_cuda_graph_decode_wrappers
):
# During CUDA graph capture, the metadata tensors provided by auto-deploy are stable.
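            # Each PlanParams value gets its own graph-compatible wrapper so that
            # replays with the same batch configuration reuse the cached plan.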
wrapper = self._init_decode_wrapper(
use_cuda_graph=True,
num_pages=len(kv_page_indices),
batch_size=len(kv_page_indptr) - 1,
)
_plan_decode(wrapper)
self.cached_cuda_graph_decode_wrappers[plan_params] = wrapper

        # during cuda graph capture or warm-up, refresh and return the pre-cached decode wrapper
if torch.cuda.is_current_stream_capturing() or cuda_graph_state.in_warm_up():
assert plan_params.is_generate, "Only generate is supported during cuda graph capture."
wrapper = self.cached_cuda_graph_decode_wrappers[plan_params]
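            # refresh the contents of the wrapper's persistent buffers (sliced to the
            # live lengths) so the metadata is up-to-date for graph replay; the buffer
            # addresses themselves must stay fixed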
wrapper._paged_kv_indptr_buf[: len(kv_page_indptr)].copy_(kv_page_indptr)
wrapper._paged_kv_indices_buf[: len(kv_page_indices)].copy_(kv_page_indices)
wrapper._paged_kv_last_page_len_buf[: len(kv_last_page_len)].copy_(kv_last_page_len)
return wrapper

# check for re-planning
@@ -167,14 +211,13 @@ def prepare_flashinfer_metadata(
https://docs.flashinfer.ai/api/prefill.html#flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper.plan
to understand the convention.
"""

# retrieve host-side metadata
num_prefill, num_prefill_tokens, num_decode = batch_info.tolist()
num_seq = num_prefill + num_decode
num_tokens = num_prefill_tokens + num_decode
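    # NOTE: each decode sequence contributes exactly one token, so num_decode counts
    # both sequences and tokens in the sums above.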

_GlobalFlashInferPlanner.reset()

qo_indptr = cu_seqlen[: num_seq + 1]

    # NOTE: in theory we could easily precompute batch_indices, and positions is just position_ids
@@ -394,6 +437,15 @@ def _init_workspace(si: SequenceInfo) -> torch.Tensor:
# see https://github.com/NVIDIA/TensorRT-LLM/pull/3686
buffer = torch.empty(320 * 1024 * 1024, dtype=torch.uint8, device=si.device)
cls._get_planner().init_workspace(buffer)
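        # Worst-case sizes for the persistent plan buffers: indptr holds one entry per
        # sequence plus one, indices one entry per KV-cache page, and last_page_len one
        # entry per sequence.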
cls._get_planner().paged_kv_indptr_buffer = torch.empty(
si.max_batch_size + 1, dtype=torch.int, device=si.device
)
cls._get_planner().paged_kv_indices_buffer = torch.empty(
si.num_pages, dtype=torch.int, device=si.device
)
cls._get_planner().paged_kv_last_page_len_buffer = torch.empty(
si.max_batch_size, dtype=torch.int, device=si.device
)
return buffer

return {"workspace_buffer": _init_workspace}