1313from tensorrt_llm ._torch .metadata import KVCacheParams
1414from tensorrt_llm ._torch .model_config import ModelConfig
1515from tensorrt_llm ._torch .models .modeling_starcoder2 import Starcoder2ForCausalLM
16+ from tensorrt_llm ._torch .modules .layer_norm import LayerNorm
1617from tensorrt_llm ._torch .pyexecutor .resource_manager import KVCacheManager
1718from tensorrt_llm .bindings .executor import KvCacheConfig
1819from tensorrt_llm .mapping import Mapping
@@ -162,16 +163,24 @@ def test_starcoder2_allclose_to_hf(scenario: Scenario) -> None:
162163 # Create HuggingFace model from config with random weights
163164 hf_config = Starcoder2Config .from_dict (config_dict )
164165 hf_starcoder2 = HFStarcoder2ForCausalLM (hf_config )
165- hf_starcoder2 = hf_starcoder2 .to (dtype = torch .bfloat16 , device = "cuda" )
166+ hf_starcoder2 = hf_starcoder2 .to (dtype = torch .bfloat16 , device = "cuda" ). eval ()
166167
167168 dtype = torch .bfloat16
168169 device = torch .device ("cuda" )
169170
170171 # Build TRT-LLM model and copy the same random weights from HF model
171172 with torch .device (device ), default_dtype (dtype ):
172173 model_config = ModelConfig (pretrained_config = hf_config , attn_backend = backend )
173- starcoder2 = Starcoder2ForCausalLM (model_config ).to (dtype ).to (device )
174+ starcoder2 = Starcoder2ForCausalLM (model_config ).to (dtype ).to (device ). eval ()
174175 starcoder2 .load_weights (hf_starcoder2 .state_dict ())
176+
177+ # Convert LayerNorm random weights to FP32 for numerical stability
178+ for name , module in starcoder2 .named_modules ():
179+ if isinstance (module , LayerNorm ):
180+ if hasattr (module , 'weight' ) and module .weight is not None :
181+ module .weight .data = module .weight .data .to (torch .float32 )
182+ if hasattr (module , 'bias' ) and module .bias is not None :
183+ module .bias .data = module .bias .data .to (torch .float32 )
175184
176185 num_blocks = 1
177186 tokens_per_block = 128
@@ -190,7 +199,7 @@ def test_starcoder2_allclose_to_hf(scenario: Scenario) -> None:
190199 # Context phase (no CUDA graphs for prefill)
191200 input_ids = torch .tensor (
192201 [100 , 200 , 300 , 400 , 500 , 600 , 700 , 800 ],
193- dtype = torch .long ,
202+ dtype = torch .int ,
194203 device = device ,
195204 )
196205 num_cached_tokens_per_seq = [0 ]
@@ -200,7 +209,7 @@ def test_starcoder2_allclose_to_hf(scenario: Scenario) -> None:
200209 kv_cache_manager .add_dummy_requests (request_ids , token_nums )
201210
202211 attn_metadata = metadata_cls (
203- seq_lens = torch .tensor ([input_ids .size (- 1 )], dtype = torch .long ),
212+ seq_lens = torch .tensor ([input_ids .size (- 1 )], dtype = torch .int ),
204213 num_contexts = 1 ,
205214 kv_cache_params = KVCacheParams (
206215 use_cache = True ,
@@ -213,7 +222,7 @@ def test_starcoder2_allclose_to_hf(scenario: Scenario) -> None:
213222 prompt_lens = prompt_lens ,
214223 )
215224
216- position_ids = [torch .arange (0 , input_ids .size (- 1 ), dtype = torch .long )]
225+ position_ids = [torch .arange (0 , input_ids .size (- 1 ), dtype = torch .int )]
217226 position_ids = torch .cat (position_ids ).unsqueeze (0 ).cuda ()
218227
219228 with torch .inference_mode ():
@@ -231,11 +240,11 @@ def test_starcoder2_allclose_to_hf(scenario: Scenario) -> None:
231240 torch .testing .assert_close (logits , ref .logits [:, - 1 ].float (), atol = 0.4 , rtol = 0.4 )
232241
233242 # Generation phase (optionally with CUDA graphs)
234- gen_input_ids = torch .tensor ([900 ], dtype = torch .long , device = device )
243+ gen_input_ids = torch .tensor ([900 ], dtype = torch .int , device = device )
235244 num_cached_tokens_per_seq = [input_ids .size (- 1 )]
236245
237246 attn_metadata = metadata_cls (
238- seq_lens = torch .tensor ([gen_input_ids .size (- 1 )], dtype = torch .long ),
247+ seq_lens = torch .tensor ([gen_input_ids .size (- 1 )], dtype = torch .int ),
239248 num_contexts = 0 ,
240249 kv_cache_params = KVCacheParams (
241250 use_cache = True ,
@@ -250,7 +259,7 @@ def test_starcoder2_allclose_to_hf(scenario: Scenario) -> None:
250259
251260 gen_position_ids = [
252261 torch .arange (
253- input_ids .size (- 1 ), input_ids .size (- 1 ) + gen_input_ids .size (- 1 ), dtype = torch .long
262+ input_ids .size (- 1 ), input_ids .size (- 1 ) + gen_input_ids .size (- 1 ), dtype = torch .int
254263 )
255264 ]
256265 gen_position_ids = torch .cat (gen_position_ids ).unsqueeze (0 ).cuda ()
@@ -296,190 +305,9 @@ def test_starcoder2_allclose_to_hf(scenario: Scenario) -> None:
296305 past_key_values = ref .past_key_values ,
297306 use_cache = True ,
298307 )
299- torch .testing .assert_close (logits , ref .logits [:, - 1 ].float (), atol = 0.4 , rtol = 0.4 )
308+ torch .testing .assert_close (logits , ref .logits [:, - 1 ].float (), atol = 0.1 , rtol = 0.1 )
300309
301310 # Cleanup
302311 if graph_runner is not None :
303312 graph_runner .clear ()
304313 kv_cache_manager .shutdown ()
305-
306-
307- @pytest .mark .parametrize (
308- "scenario" ,
309- [
310- # Test token-level generation for different model sizes
311- Scenario (backend = "TRTLLM" , config_name = "3B" ),
312- Scenario (backend = "TRTLLM" , config_name = "7B" ),
313- Scenario (backend = "TRTLLM" , config_name = "15B" ),
314- ],
315- ids = str ,
316- )
317- @torch .no_grad ()
318- def test_starcoder2_generated_tokens_match_hf (scenario : Scenario ) -> None :
319- """
320- Compare generated tokens from TRT-LLM PyTorch backend to HuggingFace.
321- Uses randomly initialized models with identical weights.
322- """
323- backend = scenario .backend
324- config_name = scenario .config_name
325-
326- torch .random .manual_seed (0 )
327-
328- # Create config based on model size
329- config_mapping = {
330- "3B" : STARCODER2_3B_CONFIG ,
331- "7B" : STARCODER2_7B_CONFIG ,
332- "15B" : STARCODER2_15B_CONFIG ,
333- }
334- config_dict = deepcopy (config_mapping [config_name ])
335-
336- # Create HuggingFace model from config with random weights
337- hf_config = Starcoder2Config .from_dict (config_dict )
338- hf_starcoder2 = HFStarcoder2ForCausalLM (hf_config )
339- hf_starcoder2 = hf_starcoder2 .to (dtype = torch .bfloat16 , device = "cuda" )
340-
341- dtype = torch .bfloat16
342- device = torch .device ("cuda" )
343-
344- # Build TRT-LLM model and copy the same random weights from HF model
345- with torch .device (device ), default_dtype (dtype ):
346- model_config = ModelConfig (pretrained_config = hf_config , attn_backend = backend )
347- starcoder2 = Starcoder2ForCausalLM (model_config ).to (dtype ).to (device )
348- starcoder2 .load_weights (hf_starcoder2 .state_dict ())
349-
350- test_prompt = "def fibonacci(n):"
351- # Create a simple tokenizer for the test (just split by characters for simplicity)
352- # Use a fixed token mapping for deterministic testing
353- input_ids = torch .tensor (
354- [100 , 200 , 300 , 400 , 500 ], # Fixed token IDs for testing
355- dtype = torch .long ,
356- device = device ,
357- )
358-
359- # Setup KV cache for TRT-LLM generation
360- num_blocks = 2
361- tokens_per_block = 128
362- max_seq_len = num_blocks * tokens_per_block
363- batch_size = 1
364-
365- kv_cache_manager = get_kv_cache_manager (
366- dtype = dtype ,
367- config = hf_config ,
368- tokens_per_block = tokens_per_block ,
369- max_seq_len = max_seq_len ,
370- batch_size = batch_size ,
371- num_blocks = num_blocks ,
372- )
373-
374- # Generate tokens with TRT-LLM (manual generation loop)
375- max_new_tokens = 20
376- trt_output_ids = []
377- num_cached_tokens = 0
378- request_ids = [1 ]
379- prompt_lens = [input_ids .size (- 1 )]
380- metadata_cls = get_attention_backend (backend ).Metadata
381-
382- # Context phase - process initial prompt
383- token_nums = [input_ids .size (- 1 )]
384- kv_cache_manager .add_dummy_requests (request_ids , token_nums )
385-
386- attn_metadata = metadata_cls (
387- seq_lens = torch .tensor ([input_ids .size (- 1 )], dtype = torch .long ),
388- num_contexts = 1 ,
389- kv_cache_params = KVCacheParams (
390- use_cache = True ,
391- num_cached_tokens_per_seq = [0 ],
392- ),
393- kv_cache_manager = kv_cache_manager ,
394- request_ids = request_ids ,
395- prompt_lens = prompt_lens ,
396- max_num_requests = 1 ,
397- max_num_tokens = 8192 ,
398- )
399-
400- position_ids = torch .arange (
401- 0 , input_ids .size (- 1 ), dtype = torch .long , device = device
402- ).unsqueeze (0 )
403-
404- with torch .inference_mode ():
405- attn_metadata .prepare ()
406- logits = starcoder2 .forward (
407- input_ids = input_ids ,
408- position_ids = position_ids ,
409- attn_metadata = attn_metadata ,
410- )
411-
412- # Get first token
413- next_token_id = torch .argmax (logits , dim = - 1 ).item ()
414- trt_output_ids .append (next_token_id )
415- num_cached_tokens = input_ids .size (- 1 )
416-
417- # Generation phase - generate remaining tokens
418- for step in range (1 , max_new_tokens ):
419- gen_input_ids = torch .tensor ([next_token_id ], dtype = torch .long , device = device )
420-
421- attn_metadata = metadata_cls (
422- seq_lens = torch .tensor ([1 ], dtype = torch .long ),
423- num_contexts = 0 ,
424- kv_cache_params = KVCacheParams (
425- use_cache = True ,
426- num_cached_tokens_per_seq = [num_cached_tokens ],
427- ),
428- kv_cache_manager = kv_cache_manager ,
429- request_ids = request_ids ,
430- prompt_lens = prompt_lens ,
431- max_num_requests = 1 ,
432- max_num_tokens = 8192 ,
433- )
434-
435- gen_position_ids = torch .arange (
436- num_cached_tokens , num_cached_tokens + 1 , dtype = torch .long , device = device
437- ).unsqueeze (0 )
438-
439- with torch .inference_mode ():
440- attn_metadata .prepare ()
441- logits = starcoder2 .forward (
442- input_ids = gen_input_ids ,
443- position_ids = gen_position_ids ,
444- attn_metadata = attn_metadata ,
445- )
446-
447- # Greedy sampling: take argmax
448- next_token_id = torch .argmax (logits , dim = - 1 ).item ()
449- trt_output_ids .append (next_token_id )
450- num_cached_tokens += 1
451-
452- # Generate with HuggingFace for comparison (manual loop for consistency)
453- hf_output_ids = []
454- hf_past_key_values = None
455- hf_current_ids = input_ids .unsqueeze (0 )
456-
457- with torch .inference_mode ():
458- for step in range (max_new_tokens ):
459- hf_output = hf_starcoder2 .forward (
460- input_ids = hf_current_ids ,
461- past_key_values = hf_past_key_values ,
462- use_cache = True ,
463- )
464- # Greedy sampling: take argmax
465- next_token_id = torch .argmax (hf_output .logits [:, - 1 , :], dim = - 1 ).item ()
466- hf_output_ids .append (next_token_id )
467- hf_past_key_values = hf_output .past_key_values
468- hf_current_ids = torch .tensor ([[next_token_id ]], dtype = torch .long , device = device )
469-
470- # Compare outputs - both should match exactly with same random weights
471- min_len = min (len (trt_output_ids ), len (hf_output_ids ))
472- matches = sum (1 for i in range (min_len ) if trt_output_ids [i ] == hf_output_ids [i ])
473- match_ratio = matches / min_len if min_len > 0 else 0.0
474-
475- # Print for debugging
476- print (f"\n { config_name } /{ backend } TRT output tokens: { trt_output_ids } " )
477- print (f"{ config_name } /{ backend } HF output tokens: { hf_output_ids } " )
478- print (f"Match ratio: { match_ratio :.2%} ({ matches } /{ min_len } tokens)" )
479-
480- # Should match exactly with identical random weights
481- assert match_ratio == 1.0 , (
482- f"TRT-LLM and HF token outputs should match exactly: { match_ratio :.2%} match"
483- )
484-
485- kv_cache_manager .shutdown ()
0 commit comments