Skip to content

Commit 7dbecf7

Browse files
authored
[TRTLLM-4923][feat] Enable CUDA graphs for Nemotron-H (NVIDIA#5646)
Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
1 parent 3c9dd5c commit 7dbecf7

File tree

3 files changed

+89
-20
lines changed

3 files changed

+89
-20
lines changed

tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -163,15 +163,8 @@ def forward(
163163
seqlen_split_size = [num_prefill_tokens, num_decode_tokens]
164164
batch_split_size = [num_prefills, num_decodes]
165165

166-
state_indices = attn_metadata.kv_cache_manager.get_state_indices()
167-
168-
# warm up does not prepare resources, so no relevant state indices
169-
is_warmup = state_indices.numel() == 0
170-
if is_warmup:
171-
# in this case, assume batch takes first indices in mamba cache
172-
state_indices = torch.arange(num_prefills + num_decodes,
173-
device=state_indices.device,
174-
dtype=state_indices.dtype)
166+
state_indices = attn_metadata.kv_cache_manager.get_state_indices(
167+
)[:num_prefills + num_decodes]
175168

176169
state_indices_p, state_indices_d = torch.split(state_indices,
177170
batch_split_size)

tensorrt_llm/_torch/pyexecutor/resource_manager.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -812,7 +812,7 @@ def __init__(
812812
self.mamba_cache_index: Dict[int, int] = {}
813813

814814
# mamba cache state indices
815-
self.state_indices: torch.Tensor = torch.tensor([],
815+
self.state_indices: torch.Tensor = torch.arange(max_batch_size,
816816
device=device,
817817
dtype=torch.int32)
818818

@@ -829,9 +829,8 @@ def prepare_mamba_cache_blocks(self, request_ids: List[int]):
829829
block = self.mamba_cache_free_blocks.pop()
830830
self.mamba_cache_index[r] = block
831831
state_indices.append(block)
832-
self.state_indices = torch.as_tensor(state_indices,
833-
dtype=torch.int32,
834-
device=self.ssm_states.device)
832+
self.state_indices[:len(state_indices)] = torch.as_tensor(
833+
state_indices, dtype=torch.int32, device=self.ssm_states.device)
835834

836835
def free_mamba_cache_blocks(self, request_id: int):
837836
if request_id in self.mamba_cache_index:

tests/unittest/_torch/modeling/test_modeling_nemotron_h.py

Lines changed: 84 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from tensorrt_llm import LLM
66
from tensorrt_llm.llmapi import KvCacheConfig
77
from tensorrt_llm.llmapi.llm import RequestOutput
8+
from tensorrt_llm.llmapi.llm_args import CudaGraphConfig
89
from tensorrt_llm.sampling_params import SamplingParams
910

1011

@@ -28,25 +29,36 @@ def extract_decode_logprobs(result: RequestOutput,
2829
return get_logprobs(token_ids, logits)
2930

3031

32+
def create_nemotron_h_llm(use_cuda_graph, disable_overlap_scheduler,
33+
max_batch_size):
34+
"""Create an LLM configured with the given CUDA graph, overlap scheduler, and max batch size settings"""
35+
model_dir = f"{llm_models_root(check=True)}/Nemotron-H-8B-Base-8K"
36+
return LLM(
37+
model=model_dir,
38+
tensor_parallel_size=1,
39+
max_batch_size=max_batch_size,
40+
cuda_graph_config=CudaGraphConfig() if use_cuda_graph else None,
41+
disable_overlap_scheduler=disable_overlap_scheduler,
42+
kv_cache_config=KvCacheConfig(enable_block_reuse=False),
43+
enable_trtllm_sampler=True,
44+
)
45+
46+
3147
@skip_gpu_memory_less_than(
3248
(2 * 8 + 1) * 2**30) # 8B, bf16, plus 1 GB for good measure
3349
def test_nemotron_h_correctness():
3450
# This test is close to memory limit on A30 (with 24GB), so empty cache first
3551
torch.cuda.empty_cache()
3652

37-
model_dir = f"{llm_models_root(check=True)}/Nemotron-H-8B-Base-8K"
3853
text_prompts = [
3954
"The future of AI is",
4055
"The president of the United States is",
4156
]
4257
num_prompts = len(text_prompts)
4358

44-
nemotron_h = LLM(
45-
model=model_dir,
46-
max_batch_size=num_prompts,
47-
kv_cache_config=KvCacheConfig(enable_block_reuse=False),
48-
enable_trtllm_sampler=True,
49-
)
59+
nemotron_h = create_nemotron_h_llm(use_cuda_graph=False,
60+
disable_overlap_scheduler=False,
61+
max_batch_size=num_prompts)
5062

5163
expected_completions = [
5264
" bright, with endless possibilities for innovation and growth",
@@ -223,3 +235,68 @@ def test_nemotron_h_correctness():
223235

224236
finally:
225237
nemotron_h.shutdown()
238+
239+
240+
def test_nemotron_h_cuda_graph_overlap_scheduler():
241+
prompts = [
242+
"Tell me something I don't know about the future of AI",
243+
"The president of the United States is",
244+
"The capital of France is",
245+
"Hello, this is a beautiful day and I'm eager to start my day and",
246+
]
247+
sampling_config = SamplingParams(max_tokens=12,
248+
temperature=0.0,
249+
return_generation_logits=True)
250+
251+
# Baseline: CUDA graphs disabled, overlap scheduler disabled
252+
with create_nemotron_h_llm(use_cuda_graph=False,
253+
disable_overlap_scheduler=True,
254+
max_batch_size=16) as llm:
255+
outputs_no_cg_no_overlap = llm.generate(prompts,
256+
sampling_params=sampling_config,
257+
use_tqdm=True)
258+
259+
# CUDA graphs enabled, overlap scheduler disabled
260+
with create_nemotron_h_llm(use_cuda_graph=True,
261+
disable_overlap_scheduler=True,
262+
max_batch_size=16) as llm:
263+
outputs_with_cg_no_overlap = llm.generate(
264+
prompts, sampling_params=sampling_config, use_tqdm=True)
265+
266+
# CUDA graphs enabled, overlap scheduler enabled
267+
with create_nemotron_h_llm(use_cuda_graph=True,
268+
disable_overlap_scheduler=False,
269+
max_batch_size=16) as llm:
270+
outputs_with_cg_with_overlap = llm.generate(
271+
prompts, sampling_params=sampling_config, use_tqdm=True)
272+
273+
# Verify outputs are consistent
274+
for (no_cg_no_overlap, with_cg_no_overlap,
275+
with_cg_with_overlap) in zip(outputs_no_cg_no_overlap,
276+
outputs_with_cg_no_overlap,
277+
outputs_with_cg_with_overlap):
278+
279+
assert (no_cg_no_overlap.outputs[0].text ==
280+
with_cg_no_overlap.outputs[0].text)
281+
assert (with_cg_no_overlap.outputs[0].text ==
282+
with_cg_with_overlap.outputs[0].text)
283+
284+
# As in other unit tests comparing runs with and without CUDA graphs, compare the logits of the first generation step (i.e. the 2nd generated token)
285+
torch.testing.assert_close(
286+
no_cg_no_overlap.outputs[0].generation_logits[1, :],
287+
with_cg_no_overlap.outputs[0].generation_logits[1, :],
288+
atol=0.2,
289+
rtol=0.2)
290+
291+
# compare logprobs of all generated tokens
292+
torch.testing.assert_close(extract_decode_logprobs(no_cg_no_overlap),
293+
extract_decode_logprobs(with_cg_no_overlap),
294+
atol=0.2,
295+
rtol=0.2)
296+
297+
# overlap scheduler should have no effect on all logits - low tolerance
298+
torch.testing.assert_close(
299+
with_cg_no_overlap.outputs[0].generation_logits,
300+
with_cg_with_overlap.outputs[0].generation_logits,
301+
atol=0.05,
302+
rtol=0.05)

0 commit comments

Comments
 (0)