diff --git a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
index 187566f62eb..852f2e063da 100644
--- a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
+++ b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
@@ -18,6 +18,7 @@
 from ..speculative.mtp import SampleStateTensorsMTP
 from ..utils import make_weak_ref, piecewise_cuda_graph
 from .llm_request import get_draft_token_length
+from .mamba_cache_manager import MambaCacheManager
 from .resource_manager import (BaseResourceManager, ResourceManager,
                                ResourceManagerType)
 from .sampler import SampleStateTensors
@@ -450,6 +451,11 @@ def _get_padded_batch(self, batch: ScheduledRequests,
             if spec_res_mgr:
                 spec_res_mgr.add_dummy_requests([CUDA_GRAPH_DUMMY_REQUEST_ID])
 
+        # Handle padding requests specially for MambaCacheManager / MambaHybridCacheManager.
+        if isinstance(kv_cache_manager, MambaCacheManager):
+            kv_cache_manager.reorder_state_indices_when_padding_requests(
+                batch_size, padding_size)
+
         self.padding_dummy_request.py_draft_tokens = [0] * runtime_draft_len
         batch.generation_requests.extend([self.padding_dummy_request] *
                                          padding_size)
diff --git a/tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py b/tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py
index 767a70dd416..f4af1975a1c 100644
--- a/tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py
+++ b/tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py
@@ -127,6 +127,23 @@ def _prepare_mamba_cache_blocks(self, request_ids: List[int]):
                 state_indices, dtype=torch.int32, pin_memory=True),
             non_blocking=True)
 
+    # When there are padded requests, the state indices must not be repeated.
+    def reorder_state_indices_when_padding_requests(self, request_size,
+                                                    padding_size):
+        if padding_size == 0:
+            return
+
+        # Use free mamba cache blocks for the padding requests.
+        assert len(
+            self.mamba_cache_free_blocks
+        ) >= padding_size, "Not enough free mamba cache blocks for padding requests"
+        self.state_indices[request_size:request_size +
+                           padding_size] = torch.tensor(
+                               self.mamba_cache_free_blocks[:padding_size],
+                               dtype=self.state_indices.dtype,
+                               pin_memory=True).to(self.state_indices.device,
+                                                   non_blocking=True)
+
     def prepare_resources(self, scheduled_batch: ScheduledRequests):
         context_ids = [
             i.py_request_id for i in scheduled_batch.context_requests
diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
index 76f45b15137..909d18ceac4 100644
--- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py
+++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -4602,10 +4602,12 @@ def test_bf16_4gpu(self, tp_size, pp_size, ep_size, cuda_graph,
         model_path = f"{self.MODEL_PATH}/Qwen3-Next-80B-A3B-Instruct"
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.6,
                                         enable_block_reuse=False)
-        pytorch_config = dict(disable_overlap_scheduler=not overlap_scheduler,
-                              cuda_graph_config=CudaGraphConfig(
-                                  max_batch_size=512, enable_padding=True)
-                              if cuda_graph else None)
+        pytorch_config = dict(
+            disable_overlap_scheduler=not overlap_scheduler,
+            cuda_graph_config=CudaGraphConfig(
+                enable_padding=True,
+                batch_sizes=[1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048])
+            if cuda_graph else None)
 
         with LLM(
                 model_path,
@@ -4620,6 +4622,7 @@ def test_bf16_4gpu(self, tp_size, pp_size, ep_size, cuda_graph,
             task.evaluate(llm)
 
             mocker.patch.object(GSM8K, "MAX_OUTPUT_LEN",
                                 self.GSM8K_MAX_OUTPUT_LEN)
+            mocker.patch.object(GSM8K, "NUM_SAMPLES", 1319)
             task = GSM8K(self.MODEL_NAME)
             task.evaluate(llm)
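
For context, the standalone sketch below illustrates the idea behind reorder_state_indices_when_padding_requests: with enable_padding, dummy generation requests are appended until the batch reaches a CUDA-graph batch size, and the padded slots must point at distinct free Mamba state blocks rather than repeating a live request's state index. The helper name pad_state_indices and the free_blocks argument are illustrative stand-ins (assumptions for this sketch), not actual TensorRT-LLM APIs.

import torch


def pad_state_indices(state_indices: torch.Tensor, free_blocks: list,
                      request_size: int, padding_size: int) -> torch.Tensor:
    """Point each padded batch slot at its own free Mamba state block (illustrative helper)."""
    if padding_size == 0:
        return state_indices
    # Padding must not exhaust the pool of unused state blocks.
    assert len(free_blocks) >= padding_size, \
        "Not enough free mamba cache blocks for padding requests"
    # Overwrite only the padded tail; real requests keep their blocks.
    state_indices[request_size:request_size + padding_size] = torch.tensor(
        free_blocks[:padding_size], dtype=state_indices.dtype)
    return state_indices


if __name__ == "__main__":
    # Two real requests use blocks 3 and 7; the batch is padded to 4 slots, and
    # a naive padding scheme would repeat block 7 for the dummy slots.
    indices = torch.tensor([3, 7, 7, 7], dtype=torch.int32)
    padded = pad_state_indices(indices, free_blocks=[0, 1, 2, 4],
                               request_size=2, padding_size=2)
    print(padded.tolist())  # [3, 7, 0, 1] -> dummy slots get distinct free blocks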