 import pytest
 import torch
 from utils.llm_data import llm_models_root
-from utils.util import similar, skip_gpu_memory_less_than
+from utils.util import skip_fp8_pre_ada, skip_gpu_memory_less_than
 
 from tensorrt_llm import LLM
 from tensorrt_llm.llmapi import KvCacheConfig
 from tensorrt_llm.llmapi.llm import RequestOutput
-from tensorrt_llm.llmapi.llm_args import CudaGraphConfig
+from tensorrt_llm.llmapi.llm_args import CudaGraphConfig, LoadFormat
 from tensorrt_llm.sampling_params import SamplingParams
 
 
@@ -37,9 +37,15 @@ def create_nemotron_h_llm(model_folder,
                           max_batch_size,
                           mamba_ssm_cache_dtype=None,
                           enable_chunked_prefill=False,
-                          max_num_tokens=8192):
+                          max_num_tokens=8192,
+                          load_format=None):
     """Create LLM with specific overlap scheduler setting"""
     model_dir = f"{llm_models_root(check=True)}/{model_folder}"
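+    # Collect optional kwargs so values left as None fall back to the LLM
+    # constructor's own defaults instead of being passed through explicitly.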
+    kwargs = {}
+    if max_num_tokens is not None:
+        kwargs["max_num_tokens"] = max_num_tokens
+    if load_format is not None:
+        kwargs["load_format"] = load_format
     return LLM(
         model=model_dir,
         tensor_parallel_size=1,
@@ -52,19 +58,71 @@ def create_nemotron_h_llm(model_folder,
             if mamba_ssm_cache_dtype is None else mamba_ssm_cache_dtype),
         sampler_type="TRTLLMSampler",
         enable_chunked_prefill=enable_chunked_prefill,
-        **({} if max_num_tokens is None else {
-            "max_num_tokens": max_num_tokens
-        }),
+        **kwargs,
     )
 
 
+@pytest.mark.parametrize("mamba_ssm_cache_dtype", [None, "float32"],
+                         ids=lambda n: f"mamba_ssm_cache_dtype:{n}")
+@pytest.mark.parametrize("model_folder", [
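+    # Memory guards: roughly 2 bytes/param (BF16) vs. 1 byte/param (FP8) for
+    # the 30B checkpoints, plus 1 GiB of headroom.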
+    pytest.param("NVIDIA-Nemotron-3-Nano-30B-A3B-BF16",
+                 marks=skip_gpu_memory_less_than((2 * 30 + 1) * 2**30)),
+    pytest.param("NVIDIA-Nemotron-3-Nano-30B-A3B-FP8",
+                 marks=skip_gpu_memory_less_than((30 + 1) * 2**30)),
+])
+def test_nemotron_h_sanity(mamba_ssm_cache_dtype, model_folder):
+    # Skip test if FP8 is not supported on the current architecture.
+    use_fp8 = model_folder == "NVIDIA-Nemotron-3-Nano-30B-A3B-FP8"
+    skip_fp8_pre_ada(use_fp8)
+
+    torch.cuda.empty_cache()
+
+    text_prompts = [
+        "The future of AI is",
+        "The president of the United States is",
+    ]
+    num_prompts = len(text_prompts)
+
+    with create_nemotron_h_llm(
+            model_folder=model_folder,
+            use_cuda_graph=False,
+            disable_overlap_scheduler=False,
+            max_batch_size=num_prompts,
+            mamba_ssm_cache_dtype=mamba_ssm_cache_dtype,
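+            # Dummy weights skip checkpoint loading; this sanity test only
+            # exercises the execution paths, not output quality.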
+            load_format=LoadFormat.DUMMY,
+    ) as nemotron_h:
+        sampling_params = SamplingParams(max_tokens=9,
+                                         temperature=0.0,
+                                         add_special_tokens=False,
+                                         return_context_logits=True,
+                                         return_generation_logits=True)
+
+        # Non-batching prefill sanity check.
+        _ = [
+            nemotron_h.generate(text_prompt, sampling_params)
+            for text_prompt in text_prompts
+        ]
+
+        # Batching prefill sanity check.
+        results_batching = nemotron_h.generate(text_prompts, sampling_params)
+        completions_batching = [
+            result.outputs[0].text for result in results_batching
+        ]
+
+        # Decoding sanity check.
+        text_prompts_with_completions = [
+            f"{text_prompts[i]} {completions_batching[i]}"
+            for i in range(num_prompts)
+        ]
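+        # A single extra token per extended prompt is enough to exercise the
+        # decode path on top of the batched completions above.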
+        sampling_params.max_tokens = 1
+        nemotron_h.generate(text_prompts_with_completions, sampling_params)
+
+
 @pytest.mark.parametrize("mamba_ssm_cache_dtype", [None, "float32"],
                          ids=lambda n: f"mamba_ssm_cache_dtype:{n}")
 @pytest.mark.parametrize("model_folder", [
     pytest.param("Nemotron-H-8B-Base-8K",
                  marks=skip_gpu_memory_less_than((2 * 8 + 1) * 2**30)),
-    pytest.param("Nemotron-Nano-3-30B-A3.5B-dev-1024",
-                 marks=skip_gpu_memory_less_than((2 * 30 + 1) * 2**30)),
 ])
 def test_nemotron_h_correctness(mamba_ssm_cache_dtype, model_folder):
     torch.cuda.empty_cache()
@@ -152,50 +210,6 @@ def test_nemotron_h_correctness(mamba_ssm_cache_dtype, model_folder):
                 -0.04291720315814018
             ])
         ]
-    elif model_folder == "Nemotron-Nano-3-30B-A3.5B-dev-1024":
-
-        expected_completions = [
-            " bright, with endless possibilities for innovation and growth",
-            " the head of state and head of government of",
-        ]
-
-        # Copied from prefill_logprobs_no_batching[0] directly.
-        prefill_logprobs_ref_mcore = torch.tensor(
-            [-8.5145, -0.8952, -2.3531, -1.6690])
-
-        # reference logprobs from initial implementation (commit e4e42e0ec30227866ce30fc9c93d5e49352bb79c on single H200).
-        initial_impl_atol = 2.0
-        batching_atol = 2.0
-
-        prefill_logprobs_ref_initial_no_batching = [
-            torch.tensor([-8.5145, -0.8952, -2.3531, -1.6690]),
-            torch.tensor([-9.9306, -1.4935, -0.4787, -1.4945, -0.0195, -1.5253])
-        ]
-        prefill_logprobs_ref_initial_with_batching = [
-            torch.tensor([-8.5221, -0.8114, -2.4334, -1.6909]),
-            torch.tensor([-9.9466, -1.5095, -0.5282, -1.4701, -0.0185, -1.4108])
-        ]
-
-        decode_logprobs_ref_initial_no_batching = [
-            torch.tensor([
-                -9.2718e-01, -9.7786e-01, -7.5823e-01, -3.3243e-01, -8.7978e-01,
-                -3.2046e-02, -9.5047e-01, -9.2678e-01, -2.5973e-04
-            ]),
-            torch.tensor([
-                -1.6836, -0.8289, -0.0063, -0.5166, -0.1798, -0.6075, -1.0987,
-                -0.9075, -0.0025
-            ])
-        ]
-        decode_logprobs_ref_initial_with_batching = [
-            torch.tensor([
-                -9.0849e-01, -9.3238e-01, -8.2788e-01, -3.5542e-01, -9.0881e-01,
-                -3.4794e-02, -9.4975e-01, -9.2631e-01, -2.4041e-04
-            ]),
-            torch.tensor([
-                -1.6331, -0.7666, -0.0063, -0.5110, -0.1617, -0.6578, -1.1073,
-                -1.1447, -0.0024
-            ])
-        ]
     else:
         raise ValueError(f"Invalid model folder: {model_folder}")
 
@@ -255,14 +269,9 @@ def test_nemotron_h_correctness(mamba_ssm_cache_dtype, model_folder):
             atol=initial_impl_atol,
             rtol=0.0)
 
-        if model_folder == "Nemotron-H-8B-Base-8K":
-            # compare expected completion
-            assert completions_batching[i] == expected_completions[i]
-            assert completions_no_batching[i] == expected_completions[i]
-        else:
-            assert similar(completions_batching[i],
-                           completions_no_batching[i],
-                           threshold=0.5)
+        # compare expected completion
+        assert completions_batching[i] == expected_completions[i]
+        assert completions_no_batching[i] == expected_completions[i]
 
         # compare decode logprobs with initial implementation
         torch.testing.assert_close(