diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py
index ac869d765a5..71d34af1a64 100644
--- a/tensorrt_llm/llmapi/llm.py
+++ b/tensorrt_llm/llmapi/llm.py
@@ -31,7 +31,7 @@
                         GenerationResult, IterationResult, LoRARequest,
                         PostprocWorkerConfig, PromptAdapterRequest)
 from ..executor.postproc_worker import PostprocParams
-from ..executor.utils import (create_mpi_comm_session,
+from ..executor.utils import (RequestError, create_mpi_comm_session,
                               get_spawn_proxy_process_env)
 from ..inputs import (PromptInputs, create_input_processor,
                       create_input_processor_with_hash, get_cache_salt_id,
@@ -686,7 +686,7 @@ def _check_arguments(self, prompt_len: int, query_len: int,
         if self.args.backend == "pytorch" and not self.args.enable_chunked_prefill and not is_gen_only:
             max_num_tokens = self.args.max_num_tokens
             if max_num_tokens and prompt_len / self.args.parallel_config.cp_size + query_len > max_num_tokens:
-                raise ValueError(
+                raise RequestError(
                     f"The sum of prompt length ({prompt_len/self.args.parallel_config.cp_size}), query length ({query_len}) should not exceed "
                     f"max_num_tokens ({max_num_tokens})")
         return
diff --git a/tests/unittest/llmapi/test_llm.py b/tests/unittest/llmapi/test_llm.py
index f8ffe8fc7bd..87624b61b8e 100644
--- a/tests/unittest/llmapi/test_llm.py
+++ b/tests/unittest/llmapi/test_llm.py
@@ -2393,7 +2393,8 @@ def fail_path():
               enable_chunked_prefill=False,
               fast_build=True)

-    with pytest.raises(ValueError):
+    # max_num_tokens validation now raises RequestError consistently
+    with pytest.raises(RequestError):
         output = llm.generate_async(
             "A " * build_config.max_num_tokens,
             sampling_params=sampling_params,
@@ -2436,13 +2437,9 @@ def _test_llm_capture_request_error(pytorch_backend: bool, tp_size: int = 1):
     )

     prompt = 'A ' * 65  # the minimum max_num_tokens is 64
-    if pytorch_backend:
-        # pytorch backend will raise ValueError for max_num_tokens
-        with pytest.raises(ValueError):
-            llm.generate(prompt)
-    else:
-        with pytest.raises(RequestError):
-            llm.generate(prompt)
+    # Both backends now consistently raise RequestError for max_num_tokens validation
+    with pytest.raises(RequestError):
+        llm.generate(prompt)


 def test_llm_capture_request_error():
diff --git a/tests/unittest/llmapi/test_llm_pytorch.py b/tests/unittest/llmapi/test_llm_pytorch.py
index b9702133152..86f48d31266 100644
--- a/tests/unittest/llmapi/test_llm_pytorch.py
+++ b/tests/unittest/llmapi/test_llm_pytorch.py
@@ -8,7 +8,7 @@

 from tensorrt_llm import LLM
 from tensorrt_llm.disaggregated_params import DisaggregatedParams
-from tensorrt_llm.executor import GenerationExecutorWorker
+from tensorrt_llm.executor import GenerationExecutorWorker, RequestError
 from tensorrt_llm.executor.rpc_proxy import GenerationExecutorRpcProxy
 from tensorrt_llm.llmapi import CacheTransceiverConfig, KvCacheConfig
 from tensorrt_llm.llmapi.llm_args import NGramDecodingConfig, PeftCacheConfig
@@ -830,10 +830,13 @@ def test_max_num_token_check(self):
                   kv_cache_config=global_kvcache_config,
                   max_num_tokens=100)

-        with pytest.raises(ValueError,
-                           match="should not exceed max_num_tokens"):
-            ids = [random.randint(10, 100) for _ in range(101)]
-            llm.generate([ids])
+        try:
+            with pytest.raises(RequestError,
+                               match="should not exceed max_num_tokens"):
+                ids = [random.randint(10, 100) for _ in range(101)]
+                llm.generate([ids])
+        finally:
+            llm.shutdown()


 class FailingExecutorWorker(GenerationExecutorWorker):
@@ -962,10 +965,13 @@ def test_max_num_token_check(self):
                   kv_cache_config=global_kvcache_config,
                   max_num_tokens=100)

-        with pytest.raises(ValueError,
-                           match="should not exceed max_num_tokens"):
-            ids = [random.randint(10, 100) for _ in range(101)]
-            llm.generate([ids])
+        try:
+            with pytest.raises(RequestError,
+                               match="should not exceed max_num_tokens"):
+                ids = [random.randint(10, 100) for _ in range(101)]
+                llm.generate([ids])
+        finally:
+            llm.shutdown()


 @skip_ray
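
As a reviewer-facing illustration of the behavior this change standardizes, the sketch below shows client code catching the max_num_tokens validation failure as RequestError rather than ValueError. It only uses APIs exercised in the tests above; the model path and max_num_tokens value are assumptions for the example, not taken from the diff.

    # Minimal sketch (assumed model path and token limit), not part of the diff.
    from tensorrt_llm import LLM
    from tensorrt_llm.executor import RequestError

    llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", max_num_tokens=64)
    try:
        # A prompt longer than max_num_tokens now surfaces as a catchable
        # RequestError on both backends instead of a ValueError.
        llm.generate("A " * 65)
    except RequestError as err:
        print(f"Request rejected: {err}")
    finally:
        llm.shutdown()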