1818import pytest
1919import torch
2020from transformers .configuration_utils import PretrainedConfig
21+ from utils .llm_data import llm_models_root
22+ from utils .util import force_ampere
2123
2224from tensorrt_llm import LLM , SamplingParams
2325from tensorrt_llm ._torch .attention_backend .interface import AttentionMetadata
3133from tensorrt_llm ._torch .models .modeling_utils import (
3234 ModelConfig , register_auto_model , register_checkpoint_weight_loader ,
3335 register_config_loader )
36+ from tensorrt_llm .executor import RequestError
3437from tensorrt_llm .executor .result import CompletionOutput , GenerationResult
3538from tensorrt_llm .llmapi import CudaGraphConfig , KvCacheConfig
3639
@@ -263,11 +266,21 @@ def fixed_params():
263266
264267
@pytest.fixture(scope="module")
def model_kwargs(fixed_params) -> dict[str, Any]:
    """Keyword arguments selecting the dummy checkpoint used to build the LLM."""
    beam_width = fixed_params["max_beam_width"]
    assert beam_width == 2, "This test only works for a beam width of 2"
    # Dummy loaders stand in for a real HF checkpoint so no weights are read.
    loader = HfCheckpointLoader(
        weight_loader=DummyWeightLoader(),
        config_loader=DummyConfigLoader(),
    )
    return {
        "model": _pl.Path("dummy_path"),
        "checkpoint_loader": loader,
    }
279+
280+
281+ def _build_llm (fixed_params , input_prompts , model_kwargs ):
282+ return LLM (
283+ ** model_kwargs ,
271284 kv_cache_config = KvCacheConfig (max_tokens = 10000 ),
272285 max_batch_size = fixed_params ["max_beam_width" ] * len (
273286 input_prompts
@@ -276,16 +289,18 @@ def llm(fixed_params, input_prompts):
276289 max_beam_width = fixed_params ["max_beam_width" ],
277290 disable_overlap_scheduler = True ,
278291 cuda_graph_config = None ,
279- checkpoint_loader = HfCheckpointLoader (weight_loader = DummyWeightLoader (),
280- config_loader = DummyConfigLoader ()))
292+ )
281293
282294
@pytest.fixture(scope="module")
def llm(fixed_params, input_prompts, model_kwargs):
    """Module-scoped LLM instance built from the shared beam-search config."""
    return _build_llm(fixed_params, input_prompts, model_kwargs)
298+
299+
300+ @pytest .fixture (scope = "module" )
301+ def llm_cuda_graph (fixed_params , input_prompts , model_kwargs ):
287302 return LLM (
288- model = _pl . Path ( "dummy_path" ) ,
303+ ** model_kwargs ,
289304 kv_cache_config = KvCacheConfig (max_tokens = 10000 ),
290305 max_batch_size = fixed_params ["max_beam_width" ] * len (
291306 input_prompts
@@ -295,8 +310,7 @@ def llm_cuda_graph(fixed_params, input_prompts):
295310 disable_overlap_scheduler = False ,
296311 cuda_graph_config = CudaGraphConfig (batch_sizes = [1 , 2 , 4 , 8 ],
297312 enable_padding = True ),
298- checkpoint_loader = HfCheckpointLoader (weight_loader = DummyWeightLoader (),
299- config_loader = DummyConfigLoader ()))
313+ )
300314
301315
302316def check_generation_logits (beam : CompletionOutput ,
@@ -473,5 +487,110 @@ def test_beam_search_output_shapes_cuda_graph_and_overlap(
473487 sampling_params )
474488
475489
@force_ampere  # Save H100 resource
class TestParameterValidation:
    """Ensure that unsupported request parameters do not crash/hang the engine.

    Each test sends a request with invalid sampling parameters, asserts that
    the expected error surfaces, and then verifies the engine still serves a
    subsequent valid request.
    """

    @pytest.fixture(scope="module")
    @staticmethod
    def fixed_params():
        # Beam width deliberately > 2: the tests below rely on a mismatch
        # between the requested beam width and the engine's max_beam_width.
        return {"max_tokens": 8, "max_beam_width": 4}

    @pytest.fixture(scope="module")
    @staticmethod
    def model_kwargs() -> dict[str, Any]:
        # Unlike the module-level fixture, use a real (tiny) checkpoint so
        # request validation is exercised end-to-end.
        root = llm_models_root()
        assert root is not None
        return dict(model=root / "llama-models-v2" /
                    "TinyLlama-1.1B-Chat-v1.0", )

    # NB: Class-level fixture overrides do not work without this
    @pytest.fixture(scope="module")
    @staticmethod
    def llm(fixed_params, input_prompts, model_kwargs):
        return _build_llm(fixed_params, input_prompts, model_kwargs)

    def _check_engine_responds(self, llm: LLM, input_prompts: list[str],
                               fixed_params: dict):
        """Issue a valid beam-search request to verify the engine is alive."""
        _ = llm.generate(input_prompts,
                         sampling_params=SamplingParams(
                             max_tokens=fixed_params["max_tokens"],
                             n=1,
                             best_of=fixed_params["max_beam_width"],
                             use_beam_search=True,
                             end_id=-1,
                         ))

    def _expect_greedy_multi_return_rejected(self, llm: LLM,
                                             input_prompts: list[str],
                                             fixed_params: dict,
                                             **sampling_overrides):
        """Expect greedy decoding with best_of > 1 to raise, then check liveness.

        ``sampling_overrides`` lets callers pass ``use_beam_search=False``
        explicitly or omit the flag entirely; both paths must be rejected
        identically.
        """
        assert fixed_params["max_beam_width"] > 2
        with pytest.raises(
                ValueError,
                match=
                ".*Greedy decoding in the LLM API does not allow multiple returns.*"
        ):
            _ = llm.generate(input_prompts,
                             sampling_params=SamplingParams(
                                 max_tokens=fixed_params["max_tokens"],
                                 n=1,
                                 best_of=fixed_params["max_beam_width"],
                                 end_id=-1,
                                 **sampling_overrides,
                             ))
        # The invalid request must not have wedged the engine.
        self._check_engine_responds(llm, input_prompts, fixed_params)

    @pytest.mark.timeout(120)
    @pytest.mark.threadleak(enabled=False)
    def test_use_beam_search_false(
        self,
        llm: LLM,
        input_prompts: list[str],
        fixed_params: dict,
    ):
        # Explicitly disabling beam search while requesting several beams
        # must be rejected without bringing the engine down.
        self._expect_greedy_multi_return_rejected(llm,
                                                  input_prompts,
                                                  fixed_params,
                                                  use_beam_search=False)

    @pytest.mark.timeout(120)
    @pytest.mark.threadleak(enabled=False)
    def test_use_beam_search_ommitted(  # (sic) name kept for test-ID stability
        self,
        llm: LLM,
        input_prompts: list[str],
        fixed_params: dict,
    ):
        # Omitting use_beam_search defaults to greedy decoding, which must
        # also reject multiple returns without crashing the engine.
        self._expect_greedy_multi_return_rejected(llm, input_prompts,
                                                  fixed_params)

    @pytest.mark.timeout(120)
    @pytest.mark.threadleak(enabled=False)
    def test_smaller_beam_width(
        self,
        llm: LLM,
        input_prompts: list[str],
        fixed_params: dict,
    ):
        assert fixed_params["max_beam_width"] > 2
        # A per-request beam width smaller than the engine's max_beam_width is
        # rejected. Regex fixed: trailing "4.*" instead of "4*", which only
        # matched an optional run of literal '4's.
        with pytest.raises(
                RequestError,
                match=".*Request beam width 2 is not equal to max_beam_width 4.*"
        ):
            _ = llm.generate(input_prompts,
                             sampling_params=SamplingParams(
                                 max_tokens=fixed_params["max_tokens"],
                                 n=1,
                                 best_of=2,
                                 use_beam_search=True,
                                 end_id=-1,
                             ))
        self._check_engine_responds(llm, input_prompts, fixed_params)
594+
if __name__ == "__main__":
    # Allow running this test module directly without invoking pytest.
    pytest.main([__file__])
0 commit comments