From 08d4e368753ed9c16f64a7ca713834dac834172e Mon Sep 17 00:00:00 2001
From: jiant <107457950+JadoTu@users.noreply.github.com>
Date: Wed, 10 Dec 2025 11:48:01 +0000
Subject: [PATCH 1/3] fix mamba_cache_manager when cuda_graph_padding is
 enabled && add a test to cover this case

Signed-off-by: jiant <107457950+JadoTu@users.noreply.github.com>
---
 .../_torch/pyexecutor/cuda_graph_runner.py    |  7 +++++++
 .../_torch/pyexecutor/mamba_cache_manager.py  | 20 +++++++++++++++++++
 .../defs/accuracy/test_llm_api_pytorch.py     | 11 ++++++----
 3 files changed, 34 insertions(+), 4 deletions(-)

diff --git a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
index b8e2754a9cb..394dd56f53b 100644
--- a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
+++ b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
@@ -15,6 +15,7 @@
 from ..modules.multi_stream_utils import with_multi_stream
 from ..speculative.eagle3 import Eagle3ResourceManager
 from ..utils import make_weak_ref, piecewise_cuda_graph
+from .mamba_cache_manager import MambaCacheManager, MambaHybridCacheManager
 from .resource_manager import (BaseResourceManager, ResourceManager,
                                ResourceManagerType)
 from .scheduler import ScheduledRequests
@@ -389,6 +390,12 @@ def _get_padded_batch(self, batch: ScheduledRequests,
         if spec_res_mgr:
             spec_res_mgr.add_dummy_requests([CUDA_GRAPH_DUMMY_REQUEST_ID])
 
+        # handle the special case of padding requests with a MambaCacheManager or MambaHybridCacheManager
+        if isinstance(kv_cache_manager,
+                      (MambaCacheManager, MambaHybridCacheManager)):
+            kv_cache_manager.reorder_state_indices_when_padding_requests(
+                batch_size, padding_size)
+
         self.padding_dummy_request.py_draft_tokens = [0] * runtime_draft_len
         batch.generation_requests.extend([self.padding_dummy_request] *
                                          padding_size)
diff --git a/tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py b/tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py
index e33b6d36bfa..75620e9f714 100644
--- a/tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py
+++ b/tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py
@@ -110,6 +110,14 @@ def __init__(
             device=device,
             dtype=torch.int32)
 
+        # only for `reorder_state_indices_when_padding_requests`
+        self.request_mask = torch.ones(max_batch_size,
+                                       dtype=torch.bool,
+                                       device=device)
+        self.state_indices_arange = torch.arange(max_batch_size,
+                                                 dtype=torch.int32,
+                                                 device=device)
+
     def _prepare_mamba_cache_blocks(self, request_ids: List[int]):
         state_indices = []
         for r in request_ids:
@@ -126,6 +134,18 @@ def _prepare_mamba_cache_blocks(self, request_ids: List[int]):
         self.state_indices[:len(state_indices)] = torch.as_tensor(
             state_indices, dtype=torch.int32, device=self.ssm_states.device)
 
+    # When there are padded requests, the state indices must not be repeated.
+    def reorder_state_indices_when_padding_requests(self, request_size,
+                                                    padding_size):
+        if padding_size == 0:
+            return
+
+        self.request_mask[:] = True
+        self.request_mask[self.state_indices[:request_size]] = False
+        self.state_indices[request_size:request_size +
+                           padding_size] = self.state_indices_arange[
+                               self.request_mask][:padding_size]
+
     def prepare_resources(self, scheduled_batch: ScheduledRequests):
         context_ids = [
             i.py_request_id for i in scheduled_batch.context_requests
diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
index 35e60e04360..345a3ba8744 100644
--- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py
+++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -4536,10 +4536,12 @@ def test_bf16_4gpu(self, tp_size, pp_size, ep_size, cuda_graph,
         model_path = f"{self.MODEL_PATH}/Qwen3-Next-80B-A3B-Instruct"
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.6,
                                         enable_block_reuse=False)
-        pytorch_config = dict(disable_overlap_scheduler=not overlap_scheduler,
-                              cuda_graph_config=CudaGraphConfig(
-                                  max_batch_size=512, enable_padding=True)
-                              if cuda_graph else None)
+        pytorch_config = dict(
+            disable_overlap_scheduler=not overlap_scheduler,
+            cuda_graph_config=CudaGraphConfig(
+                enable_padding=True,
+                batch_sizes=[1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048])
+            if cuda_graph else None)
 
         with LLM(
                 model_path,
@@ -4554,6 +4556,7 @@ def test_bf16_4gpu(self, tp_size, pp_size, ep_size, cuda_graph,
             task.evaluate(llm)
             mocker.patch.object(GSM8K, "MAX_OUTPUT_LEN",
                                 self.GSM8K_MAX_OUTPUT_LEN)
+            mocker.patch.object(GSM8K, "NUM_SAMPLES", 1319)
             task = GSM8K(self.MODEL_NAME)
             task.evaluate(llm)
 

From 9edd777014ca430b3189e9dff44c49854eb9a670 Mon Sep 17 00:00:00 2001
From: jiant <107457950+JadoTu@users.noreply.github.com>
Date: Wed, 24 Dec 2025 03:15:33 +0000
Subject: [PATCH 2/3] pass pre-commit check

Signed-off-by: jiant <107457950+JadoTu@users.noreply.github.com>
---
 tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
index c5c25618a4b..cb05c3b2efb 100644
--- a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
+++ b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
@@ -17,8 +17,8 @@
 from ..speculative.eagle3 import Eagle3ResourceManager
 from ..speculative.mtp import SampleStateTensorsMTP
 from ..utils import make_weak_ref, piecewise_cuda_graph
-from .mamba_cache_manager import MambaCacheManager, MambaHybridCacheManager
 from .llm_request import get_draft_token_length
+from .mamba_cache_manager import MambaCacheManager, MambaHybridCacheManager
 from .resource_manager import (BaseResourceManager, ResourceManager,
                                ResourceManagerType)
 from .sampler import SampleStateTensors

From e876a2096aee8f9be39f2b742bcf17a2f82fb0d9 Mon Sep 17 00:00:00 2001
From: jiant <107457950+JadoTu@users.noreply.github.com>
Date: Mon, 29 Dec 2025 10:04:49 +0000
Subject: [PATCH 3/3] new low-overhead strategy

Signed-off-by: jiant <107457950+JadoTu@users.noreply.github.com>
---
 .../_torch/pyexecutor/cuda_graph_runner.py    |  5 ++---
 .../_torch/pyexecutor/mamba_cache_manager.py  | 21 ++++++++-----------
 2 files changed, 11 insertions(+), 15 deletions(-)

diff --git a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
index 6fba4826e12..852f2e063da 100644
--- a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
+++ b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
@@ -18,7 +18,7 @@
 from ..speculative.mtp import SampleStateTensorsMTP
 from ..utils import make_weak_ref, piecewise_cuda_graph
 from .llm_request import get_draft_token_length
-from .mamba_cache_manager import MambaCacheManager, MambaHybridCacheManager
+from .mamba_cache_manager import MambaCacheManager
 from .resource_manager import (BaseResourceManager, ResourceManager,
                                ResourceManagerType)
 from .sampler import SampleStateTensors
@@ -452,8 +452,7 @@ def _get_padded_batch(self, batch: ScheduledRequests,
             spec_res_mgr.add_dummy_requests([CUDA_GRAPH_DUMMY_REQUEST_ID])
 
         # handle the special case of padding requests with a MambaCacheManager or MambaHybridCacheManager
-        if isinstance(kv_cache_manager,
-                      (MambaCacheManager, MambaHybridCacheManager)):
+        if isinstance(kv_cache_manager, MambaCacheManager):
             kv_cache_manager.reorder_state_indices_when_padding_requests(
                 batch_size, padding_size)
 
diff --git a/tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py b/tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py
index c9131e296df..f4af1975a1c 100644
--- a/tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py
+++ b/tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py
@@ -110,14 +110,6 @@ def __init__(
             device=device,
             dtype=torch.int32)
 
-        # only for `reorder_state_indices_when_padding_requests`
-        self.request_mask = torch.ones(max_batch_size,
-                                       dtype=torch.bool,
-                                       device=device)
-        self.state_indices_arange = torch.arange(max_batch_size,
-                                                 dtype=torch.int32,
-                                                 device=device)
-
     def _prepare_mamba_cache_blocks(self, request_ids: List[int]):
         state_indices = []
         for r in request_ids:
@@ -141,11 +133,16 @@ def reorder_state_indices_when_padding_requests(self, request_size,
         if padding_size == 0:
             return
 
-        self.request_mask[:] = True
-        self.request_mask[self.state_indices[:request_size]] = False
+        # we can use mamba_cache_free_blocks for the padding requests
+        assert len(
+            self.mamba_cache_free_blocks
+        ) >= padding_size, "Not enough free mamba cache blocks for padding requests"
         self.state_indices[request_size:request_size +
-                           padding_size] = self.state_indices_arange[
-                               self.request_mask][:padding_size]
+                           padding_size] = torch.tensor(
+                               self.mamba_cache_free_blocks[:padding_size],
+                               dtype=self.state_indices.dtype,
+                               pin_memory=True).to(self.state_indices.device,
+                                                   non_blocking=True)
 
     def prepare_resources(self, scheduled_batch: ScheduledRequests):
         context_ids = [