1414from tensorrt_llm ._torch .metadata import KVCacheParams
1515from tensorrt_llm ._torch .model_config import ModelConfig
1616from tensorrt_llm ._torch .models .modeling_starcoder2 import Starcoder2ForCausalLM
17- from tensorrt_llm ._torch .pyexecutor .cuda_graph_runner import CUDAGraphRunner
1817from tensorrt_llm ._torch .pyexecutor .resource_manager import KVCacheManager
1918from tensorrt_llm .bindings .executor import KvCacheConfig
2019from tensorrt_llm .mapping import Mapping
@@ -114,7 +113,7 @@ def get_kv_cache_manager(
114113 elif dtype == torch .bfloat16 :
115114 kv_cache_dtype = tensorrt_llm .bindings .DataType .BF16
116115 else :
117- raise ValueError ("Invalid dtype" )
116+ raise ValueError (f "Invalid dtype: { dtype } " )
118117
119118 mapping = Mapping (world_size = 1 , tp_size = 1 , rank = 0 )
120119 kv_cache_config = KvCacheConfig (
@@ -160,7 +159,7 @@ def test_starcoder2_sanity(self):
160159
161160 input_ids = torch .tensor (
162161 [100 , 200 , 300 , 400 , 500 , 600 , 700 , 800 ],
163- dtype = torch .int ,
162+ dtype = torch .long ,
164163 device = device ,
165164 )
166165
@@ -188,7 +187,7 @@ def test_starcoder2_sanity(self):
188187
189188 metadata_cls = get_attention_backend (model_config .attn_backend ).Metadata
190189 attn_metadata = metadata_cls (
191- seq_lens = torch .tensor (sequence_lengths , dtype = torch .int ),
190+ seq_lens = torch .tensor (sequence_lengths , dtype = torch .long ),
192191 num_contexts = len (context_sequence_lengths ),
193192 kv_cache_params = KVCacheParams (
194193 use_cache = True ,
@@ -302,7 +301,7 @@ def test_starcoder2_allclose_to_hf(self, scenario: Scenario) -> None:
302301 # Context phase (no CUDA graphs for prefill)
303302 input_ids = torch .tensor (
304303 [100 , 200 , 300 , 400 , 500 , 600 , 700 , 800 ],
305- dtype = torch .int32 ,
304+ dtype = torch .long ,
306305 device = device ,
307306 )
308307 num_cached_tokens_per_seq = [0 ]
@@ -312,7 +311,7 @@ def test_starcoder2_allclose_to_hf(self, scenario: Scenario) -> None:
312311 kv_cache_manager .add_dummy_requests (request_ids , token_nums )
313312
314313 attn_metadata = metadata_cls (
315- seq_lens = torch .tensor ([input_ids .size (- 1 )], dtype = torch .int ),
314+ seq_lens = torch .tensor ([input_ids .size (- 1 )], dtype = torch .long ),
316315 num_contexts = 1 ,
317316 kv_cache_params = KVCacheParams (
318317 use_cache = True ,
@@ -325,7 +324,7 @@ def test_starcoder2_allclose_to_hf(self, scenario: Scenario) -> None:
325324 prompt_lens = prompt_lens ,
326325 )
327326
328- position_ids = [torch .arange (0 , input_ids .size (- 1 ), dtype = torch .int32 )]
327+ position_ids = [torch .arange (0 , input_ids .size (- 1 ), dtype = torch .long )]
329328 position_ids = torch .cat (position_ids ).unsqueeze (0 ).cuda ()
330329
331330 with torch .inference_mode ():
@@ -343,11 +342,11 @@ def test_starcoder2_allclose_to_hf(self, scenario: Scenario) -> None:
343342 torch .testing .assert_close (logits , ref .logits [:, - 1 ].float (), atol = 0.4 , rtol = 0.4 )
344343
345344 # Generation phase (optionally with CUDA graphs)
346- gen_input_ids = torch .tensor ([900 ], dtype = torch .int32 , device = device )
345+ gen_input_ids = torch .tensor ([900 ], dtype = torch .long , device = device )
347346 num_cached_tokens_per_seq = [input_ids .size (- 1 )]
348347
349348 attn_metadata = metadata_cls (
350- seq_lens = torch .tensor ([gen_input_ids .size (- 1 )], dtype = torch .int ),
349+ seq_lens = torch .tensor ([gen_input_ids .size (- 1 )], dtype = torch .long ),
351350 num_contexts = 0 ,
352351 kv_cache_params = KVCacheParams (
353352 use_cache = True ,
@@ -362,18 +361,17 @@ def test_starcoder2_allclose_to_hf(self, scenario: Scenario) -> None:
362361
363362 gen_position_ids = [
364363 torch .arange (
365- input_ids .size (- 1 ), input_ids .size (- 1 ) + gen_input_ids .size (- 1 ), dtype = torch .int32
364+ input_ids .size (- 1 ), input_ids .size (- 1 ) + gen_input_ids .size (- 1 ), dtype = torch .long
366365 )
367366 ]
368367 gen_position_ids = torch .cat (gen_position_ids ).unsqueeze (0 ).cuda ()
369368
370369 # Setup CUDA graph runner if requested
371370 graph_runner = None
372371 if use_cuda_graph :
373- from _torch .helpers import create_mock_engine
372+ from _torch .helpers import create_mock_cuda_graph_runner
374373
375- mock_engine = create_mock_engine (1 )
376- graph_runner = CUDAGraphRunner (mock_engine )
374+ graph_runner = create_mock_cuda_graph_runner (1 )
377375 attn_metadata = attn_metadata .create_cuda_graph_metadata (1 )
378376
379377 # Run generation phase
@@ -476,7 +474,7 @@ def test_starcoder2_generated_tokens_match_hf(self, scenario: Scenario) -> None:
476474 # Encode test prompt
477475 input_ids = torch .tensor (
478476 tokenizer .encode (test_prompt ),
479- dtype = torch .int32 ,
477+ dtype = torch .long ,
480478 device = device ,
481479 )
482480
@@ -508,7 +506,7 @@ def test_starcoder2_generated_tokens_match_hf(self, scenario: Scenario) -> None:
508506 kv_cache_manager .add_dummy_requests (request_ids , token_nums )
509507
510508 attn_metadata = metadata_cls (
511- seq_lens = torch .tensor ([input_ids .size (- 1 )], dtype = torch .int ),
509+ seq_lens = torch .tensor ([input_ids .size (- 1 )], dtype = torch .long ),
512510 num_contexts = 1 ,
513511 kv_cache_params = KVCacheParams (
514512 use_cache = True ,
@@ -522,7 +520,7 @@ def test_starcoder2_generated_tokens_match_hf(self, scenario: Scenario) -> None:
522520 )
523521
524522 position_ids = torch .arange (
525- 0 , input_ids .size (- 1 ), dtype = torch .int32 , device = device
523+ 0 , input_ids .size (- 1 ), dtype = torch .long , device = device
526524 ).unsqueeze (0 )
527525
528526 with torch .inference_mode ():
@@ -540,10 +538,10 @@ def test_starcoder2_generated_tokens_match_hf(self, scenario: Scenario) -> None:
540538
541539 # Generation phase - generate remaining tokens
542540 for step in range (1 , max_new_tokens ):
543- gen_input_ids = torch .tensor ([next_token_id ], dtype = torch .int32 , device = device )
541+ gen_input_ids = torch .tensor ([next_token_id ], dtype = torch .long , device = device )
544542
545543 attn_metadata = metadata_cls (
546- seq_lens = torch .tensor ([1 ], dtype = torch .int ),
544+ seq_lens = torch .tensor ([1 ], dtype = torch .long ),
547545 num_contexts = 0 ,
548546 kv_cache_params = KVCacheParams (
549547 use_cache = True ,
@@ -557,7 +555,7 @@ def test_starcoder2_generated_tokens_match_hf(self, scenario: Scenario) -> None:
557555 )
558556
559557 gen_position_ids = torch .arange (
560- num_cached_tokens , num_cached_tokens + 1 , dtype = torch .int32 , device = device
558+ num_cached_tokens , num_cached_tokens + 1 , dtype = torch .long , device = device
561559 ).unsqueeze (0 )
562560
563561 with torch .inference_mode ():
0 commit comments