7 changes: 7 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
@@ -18,6 +18,7 @@
from ..speculative.mtp import SampleStateTensorsMTP
from ..utils import make_weak_ref, piecewise_cuda_graph
from .llm_request import get_draft_token_length
from .mamba_cache_manager import MambaCacheManager, MambaHybridCacheManager
from .resource_manager import (BaseResourceManager, ResourceManager,
                               ResourceManagerType)
from .sampler import SampleStateTensors
@@ -450,6 +451,12 @@ def _get_padded_batch(self, batch: ScheduledRequests,
        if spec_res_mgr:
            spec_res_mgr.add_dummy_requests([CUDA_GRAPH_DUMMY_REQUEST_ID])

        # Handle the special case of padding requests when the KV cache manager is a
        # MambaCacheManager or MambaHybridCacheManager: padded requests must not repeat
        # the state indices of real requests.
        if isinstance(kv_cache_manager,
                      (MambaCacheManager, MambaHybridCacheManager)):
            kv_cache_manager.reorder_state_indices_when_padding_requests(
                batch_size, padding_size)

        self.padding_dummy_request.py_draft_tokens = [0] * runtime_draft_len
        batch.generation_requests.extend([self.padding_dummy_request] *
                                         padding_size)
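Why the reorder is needed: the padding path extends `batch.generation_requests` with the same `padding_dummy_request` object repeated `padding_size` times, so a per-request lookup of Mamba state slots (as `_prepare_mamba_cache_blocks` in the next file does) would return the same slot for every padded entry. A minimal sketch of that effect, with a hypothetical `state_slot_of` table and made-up request ids standing in for the cache manager's internals:

```python
# Minimal sketch (not the TensorRT-LLM implementation): shows how padding a batch
# with one shared dummy request produces repeated Mamba state indices unless the
# padded entries are reassigned to unused slots.
dummy_id = 9999                         # stands in for CUDA_GRAPH_DUMMY_REQUEST_ID
real_ids = [3, 7]                       # real generation requests in the batch
padded_ids = real_ids + [dummy_id] * 2  # pad a batch of 2 up to a captured size of 4

# Hypothetical request-id -> state-slot table; the dummy request owns a single slot.
state_slot_of = {3: 0, 7: 1, dummy_id: 2}

state_indices = [state_slot_of[r] for r in padded_ids]
print(state_indices)  # [0, 1, 2, 2] -- slot 2 is repeated for every padded entry
```

The new `reorder_state_indices_when_padding_requests` call rewrites those trailing entries so that each padded slot points at an otherwise unused state index.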
20 changes: 20 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py
@@ -110,6 +110,14 @@ def __init__(
device=device,
dtype=torch.int32)

        # Scratch buffers used only by `reorder_state_indices_when_padding_requests`.
        self.request_mask = torch.ones(max_batch_size,
                                       dtype=torch.bool,
                                       device=device)
        self.state_indices_arange = torch.arange(max_batch_size,
                                                 dtype=torch.int32,
                                                 device=device)

    def _prepare_mamba_cache_blocks(self, request_ids: List[int]):
        state_indices = []
        for r in request_ids:
@@ -127,6 +135,18 @@ def _prepare_mamba_cache_blocks(self, request_ids: List[int]):
state_indices, dtype=torch.int32, pin_memory=True),
non_blocking=True)

    # When there are padded requests, the state indices must not be repeated.
    def reorder_state_indices_when_padding_requests(self, request_size,
                                                    padding_size):
        if padding_size == 0:
            return

        self.request_mask[:] = True
        self.request_mask[self.state_indices[:request_size]] = False

Member:
There are implicit CUDA synchronizations which negate the benefits of the overlap scheduler.

Author (Collaborator):
I checked whether there are CUDA synchronizations here, and there are none.
[Screenshot: profiler trace, 2025-12-29 at 15:58:04]

In any case, the newer patch switches to a new method with lower overhead.

Member:

Here are the synchronizations I mentioned. The NVTX range "====A1==" covers the code snippet I quoted.

[Screenshot: NVTX trace showing synchronizations inside the "====A1==" range]

There is no synchronization when I delete this code snippet. Note, though, that my NVTX trace was captured before your new method.

Member:
From your screenshot, I cannot tell whether any synchronization is newly introduced. Because of launch overhead, a synchronization (if it exists) would return immediately and be easy to miss.

I wanted to verify it myself, but I hit an error, so please resolve that error first.

        self.state_indices[request_size:request_size +
                           padding_size] = self.state_indices_arange[
                               self.request_mask][:padding_size]

    def prepare_resources(self, scheduled_batch: ScheduledRequests):
        context_ids = [
            i.py_request_id for i in scheduled_batch.context_requests
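To make the masking logic easier to follow outside the diff, here is a self-contained sketch of the same technique (my reconstruction with a free-standing `reorder_state_indices` helper and toy values, not the TensorRT-LLM class), followed by one way to probe the reviewer's synchronization question using PyTorch's `torch.cuda.set_sync_debug_mode`. Boolean-mask indexing on a CUDA tensor generally does force a synchronization, because the result's size depends on the mask contents:

```python
import torch


def reorder_state_indices(state_indices: torch.Tensor, request_size: int,
                          padding_size: int) -> torch.Tensor:
    """Sketch of the masking approach in the diff above: give padded entries unused slots."""
    if padding_size == 0:
        return state_indices
    max_batch_size = state_indices.numel()
    request_mask = torch.ones(max_batch_size, dtype=torch.bool,
                              device=state_indices.device)
    # Slots already owned by real requests are not free.
    request_mask[state_indices[:request_size].long()] = False
    arange = torch.arange(max_batch_size, dtype=torch.int32,
                          device=state_indices.device)
    # Boolean-mask indexing yields the free slots; take the first padding_size of them.
    state_indices[request_size:request_size + padding_size] = \
        arange[request_mask][:padding_size]
    return state_indices


# Toy example: 6 state slots, 2 real requests in slots 0 and 1, padded entries
# currently pointing at stale/repeated slots.
state_indices = torch.tensor([0, 1, 1, 1, 5, 5], dtype=torch.int32)
print(reorder_state_indices(state_indices, request_size=2, padding_size=2))
# tensor([0, 1, 2, 3, 5, 5], dtype=torch.int32) -- padded entries now use free slots 2 and 3

# To probe the review thread's question about implicit synchronization, run the CUDA
# version under sync-debug mode; boolean-mask indexing is expected to be flagged because
# its output shape depends on the mask's contents.
if torch.cuda.is_available():
    torch.cuda.set_sync_debug_mode("warn")
    reorder_state_indices(state_indices.cuda(), request_size=2, padding_size=2)
    torch.cuda.set_sync_debug_mode("default")
```

A fixed-output-size formulation (for example, `torch.topk` over the mask with `k=padding_size`) would avoid the data-dependent shape; the lower-overhead method mentioned for the newer revision is not shown in this diff, so that is only one possible direction.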
11 changes: 7 additions & 4 deletions tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -4602,10 +4602,12 @@ def test_bf16_4gpu(self, tp_size, pp_size, ep_size, cuda_graph,
        model_path = f"{self.MODEL_PATH}/Qwen3-Next-80B-A3B-Instruct"
        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.6,
                                        enable_block_reuse=False)
        pytorch_config = dict(disable_overlap_scheduler=not overlap_scheduler,
                              cuda_graph_config=CudaGraphConfig(
                                  max_batch_size=512, enable_padding=True)
                              if cuda_graph else None)
        pytorch_config = dict(
            disable_overlap_scheduler=not overlap_scheduler,
            cuda_graph_config=CudaGraphConfig(
                enable_padding=True,
                batch_sizes=[1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048])
            if cuda_graph else None)

        with LLM(
                model_path,
@@ -4620,6 +4622,7 @@ def test_bf16_4gpu(self, tp_size, pp_size, ep_size, cuda_graph,
            task.evaluate(llm)
            mocker.patch.object(GSM8K, "MAX_OUTPUT_LEN",
                                self.GSM8K_MAX_OUTPUT_LEN)
            mocker.patch.object(GSM8K, "NUM_SAMPLES", 1319)
            task = GSM8K(self.MODEL_NAME)
            task.evaluate(llm)

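For context on how the updated test exercises the new code path: with `enable_padding=True`, a batch that does not match a captured CUDA graph size is padded up to the nearest captured `batch_sizes` entry using dummy generation requests, which is exactly the situation handled by the Mamba state-index reordering above. A small sketch of that size selection, under the assumption that the runner pads to the smallest captured size that fits (the `padded_batch_size` helper is hypothetical, not the TensorRT-LLM API):

```python
from bisect import bisect_left


# Hypothetical helper illustrating padded-size selection; the real logic lives in
# cuda_graph_runner.py and may differ in details.
def padded_batch_size(actual: int, captured_sizes: list[int]) -> int:
    sizes = sorted(captured_sizes)
    idx = bisect_left(sizes, actual)
    if idx == len(sizes):
        raise ValueError("batch larger than any captured CUDA graph size")
    return sizes[idx]


captured = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
print(padded_batch_size(3, captured))    # 4   -> one dummy request is appended
print(padded_batch_size(100, captured))  # 128 -> 28 dummy requests are appended
```

The explicit `batch_sizes` list in the test (replacing `max_batch_size=512`) therefore controls which padded sizes can occur, while the added `NUM_SAMPLES` patch simply runs GSM8K on its full 1319-sample set.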