55from tensorrt_llm import LLM
66from tensorrt_llm .llmapi import KvCacheConfig
77from tensorrt_llm .llmapi .llm import RequestOutput
8+ from tensorrt_llm .llmapi .llm_args import CudaGraphConfig
89from tensorrt_llm .sampling_params import SamplingParams
910
1011
@@ -28,25 +29,36 @@ def extract_decode_logprobs(result: RequestOutput,
2829 return get_logprobs (token_ids , logits )
2930
3031
def create_nemotron_h_llm(use_cuda_graph, disable_overlap_scheduler,
                          max_batch_size):
    """Build a Nemotron-H-8B LLM instance for a specific execution mode.

    Args:
        use_cuda_graph: when truthy, enable CUDA-graph capture with a default
            ``CudaGraphConfig``; otherwise CUDA graphs are disabled.
        disable_overlap_scheduler: forwarded to the ``LLM`` constructor to
            turn the overlap scheduler off (True) or leave it on (False).
        max_batch_size: maximum number of requests batched together.

    Returns:
        A single-GPU ``LLM`` with KV-cache block reuse disabled and the
        TRT-LLM sampler enabled.
    """
    model_dir = f"{llm_models_root(check=True)}/Nemotron-H-8B-Base-8K"
    graph_config = CudaGraphConfig() if use_cuda_graph else None
    return LLM(model=model_dir,
               tensor_parallel_size=1,
               max_batch_size=max_batch_size,
               cuda_graph_config=graph_config,
               disable_overlap_scheduler=disable_overlap_scheduler,
               kv_cache_config=KvCacheConfig(enable_block_reuse=False),
               enable_trtllm_sampler=True)
45+
46+
3147@skip_gpu_memory_less_than (
3248 (2 * 8 + 1 ) * 2 ** 30 ) # 8B, bf16, plus 1 GB for good measure
3349def test_nemotron_h_correctness ():
3450 # This test is close to memory limit on A30 (with 24GB), so empty cache first
3551 torch .cuda .empty_cache ()
3652
37- model_dir = f"{ llm_models_root (check = True )} /Nemotron-H-8B-Base-8K"
3853 text_prompts = [
3954 "The future of AI is" ,
4055 "The president of the United States is" ,
4156 ]
4257 num_prompts = len (text_prompts )
4358
44- nemotron_h = LLM (
45- model = model_dir ,
46- max_batch_size = num_prompts ,
47- kv_cache_config = KvCacheConfig (enable_block_reuse = False ),
48- enable_trtllm_sampler = True ,
49- )
59+ nemotron_h = create_nemotron_h_llm (use_cuda_graph = False ,
60+ disable_overlap_scheduler = False ,
61+ max_batch_size = num_prompts )
5062
5163 expected_completions = [
5264 " bright, with endless possibilities for innovation and growth" ,
@@ -223,3 +235,68 @@ def test_nemotron_h_correctness():
223235
224236 finally :
225237 nemotron_h .shutdown ()
238+
239+
def _generate_with_config(use_cuda_graph, disable_overlap_scheduler, prompts,
                          sampling_params):
    # Run generation on a freshly built LLM with the given CUDA-graph /
    # overlap-scheduler configuration; the context manager shuts the LLM down.
    with create_nemotron_h_llm(
            use_cuda_graph=use_cuda_graph,
            disable_overlap_scheduler=disable_overlap_scheduler,
            max_batch_size=16) as llm:
        return llm.generate(prompts,
                            sampling_params=sampling_params,
                            use_tqdm=True)


def test_nemotron_h_cuda_graph_overlap_scheduler():
    """CUDA graphs and the overlap scheduler must not change generation.

    Runs the same greedy (temperature=0.0) prompts under three configurations
    (no CG / no overlap, CG / no overlap, CG / overlap) and checks that the
    generated texts match exactly and that logits and logprobs agree within
    tolerance. The overlap scheduler is expected to be a pure scheduling
    change, so it is held to a tighter tolerance than CUDA-graph capture.
    """
    prompts = [
        "Tell me something I don't know about the future of AI",
        "The president of the United States is",
        "The capital of France is",
        "Hello, this is a beautiful day and I'm eager to start my day and",
    ]
    sampling_config = SamplingParams(max_tokens=12,
                                     temperature=0.0,
                                     return_generation_logits=True)

    # Baseline: no CUDA graphs, overlap scheduler disabled.
    outputs_no_cg_no_overlap = _generate_with_config(
        use_cuda_graph=False,
        disable_overlap_scheduler=True,
        prompts=prompts,
        sampling_params=sampling_config)

    # CUDA graphs on, overlap scheduler still disabled.
    outputs_with_cg_no_overlap = _generate_with_config(
        use_cuda_graph=True,
        disable_overlap_scheduler=True,
        prompts=prompts,
        sampling_params=sampling_config)

    # CUDA graphs on, overlap scheduler enabled.
    outputs_with_cg_with_overlap = _generate_with_config(
        use_cuda_graph=True,
        disable_overlap_scheduler=False,
        prompts=prompts,
        sampling_params=sampling_config)

    # Verify outputs are consistent across all three configurations.
    for (no_cg_no_overlap, with_cg_no_overlap,
         with_cg_with_overlap) in zip(outputs_no_cg_no_overlap,
                                      outputs_with_cg_no_overlap,
                                      outputs_with_cg_with_overlap):

        # Greedy decoding should produce identical text in every mode.
        assert (no_cg_no_overlap.outputs[0].text ==
                with_cg_no_overlap.outputs[0].text)
        assert (with_cg_no_overlap.outputs[0].text ==
                with_cg_with_overlap.outputs[0].text)

        # Similar to other unittests comparing with / without CG, compare
        # logits of the first generation step (2nd generated token).
        torch.testing.assert_close(
            no_cg_no_overlap.outputs[0].generation_logits[1, :],
            with_cg_no_overlap.outputs[0].generation_logits[1, :],
            atol=0.2,
            rtol=0.2)

        # Compare logprobs of all generated tokens.
        torch.testing.assert_close(extract_decode_logprobs(no_cg_no_overlap),
                                   extract_decode_logprobs(with_cg_no_overlap),
                                   atol=0.2,
                                   rtol=0.2)

        # The overlap scheduler should have no effect on any logits — use a
        # low tolerance.
        torch.testing.assert_close(
            with_cg_no_overlap.outputs[0].generation_logits,
            with_cg_with_overlap.outputs[0].generation_logits,
            atol=0.05,
            rtol=0.05)