
Commit c99faae

[#9760][fix] Use RequestError for validation errors to prevent engine shutdown (#9761)
Signed-off-by: [email protected] <[email protected]>
1 parent 01083b5 commit c99faae

File tree: 3 files changed (+22, -19 lines)


tensorrt_llm/llmapi/llm.py (2 additions, 2 deletions)

@@ -31,7 +31,7 @@
                         GenerationResult, IterationResult, LoRARequest,
                         PostprocWorkerConfig, PromptAdapterRequest)
 from ..executor.postproc_worker import PostprocParams
-from ..executor.utils import (create_mpi_comm_session,
+from ..executor.utils import (RequestError, create_mpi_comm_session,
                               get_spawn_proxy_process_env)
 from ..inputs import (PromptInputs, create_input_processor,
                       create_input_processor_with_hash, get_cache_salt_id,
@@ -686,7 +686,7 @@ def _check_arguments(self, prompt_len: int, query_len: int,
         if self.args.backend == "pytorch" and not self.args.enable_chunked_prefill and not is_gen_only:
             max_num_tokens = self.args.max_num_tokens
             if max_num_tokens and prompt_len / self.args.parallel_config.cp_size + query_len > max_num_tokens:
-                raise ValueError(
+                raise RequestError(
                     f"The sum of prompt length ({prompt_len/self.args.parallel_config.cp_size}), query length ({query_len}) should not exceed "
                     f"max_num_tokens ({max_num_tokens})")
         return
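
For reference, a minimal self-contained sketch of the budget check this hunk modifies; the names below (check_token_budget and the local RequestError stand-in) are illustrative only, not library APIs. The check divides the prompt length by the context-parallel size (cp_size) and requires that this share plus the query length fits within max_num_tokens; with this commit the failure is reported as a RequestError rather than a ValueError.

# Illustrative sketch only; RequestError here is a local stand-in for
# tensorrt_llm.executor.utils.RequestError.
class RequestError(RuntimeError):
    """Per-request failure that should not bring down the whole engine."""


def check_token_budget(prompt_len: int, query_len: int, cp_size: int,
                       max_num_tokens: int) -> None:
    # Mirror of the condition above: the per-cp-rank share of the prompt plus
    # the query length must not exceed the configured token budget.
    per_rank_prompt = prompt_len / cp_size
    if max_num_tokens and per_rank_prompt + query_len > max_num_tokens:
        raise RequestError(
            f"The sum of prompt length ({per_rank_prompt}), query length "
            f"({query_len}) should not exceed max_num_tokens ({max_num_tokens})")


# Worked example: 180 prompt tokens with cp_size=2 give 90 per rank;
# 90 + 20 query tokens = 110 > 100, so the request is rejected.
try:
    check_token_budget(prompt_len=180, query_len=20, cp_size=2,
                       max_num_tokens=100)
except RequestError as err:
    print(err)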

tests/unittest/llmapi/test_llm.py (5 additions, 8 deletions)

@@ -2393,7 +2393,8 @@ def fail_path():
                       enable_chunked_prefill=False,
                       fast_build=True)

-    with pytest.raises(ValueError):
+    # max_num_tokens validation now raises RequestError consistently
+    with pytest.raises(RequestError):
         output = llm.generate_async(
             "A " * build_config.max_num_tokens,
             sampling_params=sampling_params,
@@ -2436,13 +2437,9 @@ def _test_llm_capture_request_error(pytorch_backend: bool, tp_size: int = 1):
     )

     prompt = 'A ' * 65  # the minimum max_num_tokens is 64
-    if pytorch_backend:
-        # pytorch backend will raise ValueError for max_num_tokens
-        with pytest.raises(ValueError):
-            llm.generate(prompt)
-    else:
-        with pytest.raises(RequestError):
-            llm.generate(prompt)
+    # Both backends now consistently raise RequestError for max_num_tokens validation
+    with pytest.raises(RequestError):
+        llm.generate(prompt)


 def test_llm_capture_request_error():
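
Since both backends now raise the same exception for this validation, callers no longer need the backend-specific branching the old test encoded. A hedged usage sketch follows; the model path is a placeholder and the snippet is not taken from the repository, but it only uses calls that appear in the diffs on this page (LLM, generate, shutdown, and RequestError from tensorrt_llm.executor). The point is that the rejected request surfaces as an ordinary, catchable error and the same LLM instance keeps serving afterwards.

# Hedged sketch; "/path/to/model" is a placeholder checkpoint location.
from tensorrt_llm import LLM
from tensorrt_llm.executor import RequestError

llm = LLM(model="/path/to/model", max_num_tokens=100)
try:
    try:
        # Deliberately exceed the 100-token budget configured above.
        llm.generate("A " * 101)
    except RequestError as err:
        print(f"request rejected: {err}")

    # The engine was not shut down by the rejected request, so a prompt that
    # fits the budget still succeeds on the same instance.
    print(llm.generate("A short prompt"))
finally:
    llm.shutdown()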

tests/unittest/llmapi/test_llm_pytorch.py (15 additions, 9 deletions)

@@ -8,7 +8,7 @@

 from tensorrt_llm import LLM
 from tensorrt_llm.disaggregated_params import DisaggregatedParams
-from tensorrt_llm.executor import GenerationExecutorWorker
+from tensorrt_llm.executor import GenerationExecutorWorker, RequestError
 from tensorrt_llm.executor.rpc_proxy import GenerationExecutorRpcProxy
 from tensorrt_llm.llmapi import CacheTransceiverConfig, KvCacheConfig
 from tensorrt_llm.llmapi.llm_args import NGramDecodingConfig, PeftCacheConfig
@@ -830,10 +830,13 @@ def test_max_num_token_check(self):
                   kv_cache_config=global_kvcache_config,
                   max_num_tokens=100)

-        with pytest.raises(ValueError,
-                           match="should not exceed max_num_tokens"):
-            ids = [random.randint(10, 100) for _ in range(101)]
-            llm.generate([ids])
+        try:
+            with pytest.raises(RequestError,
+                               match="should not exceed max_num_tokens"):
+                ids = [random.randint(10, 100) for _ in range(101)]
+                llm.generate([ids])
+        finally:
+            llm.shutdown()


 class FailingExecutorWorker(GenerationExecutorWorker):
@@ -962,10 +965,13 @@ def test_max_num_token_check(self):
                   kv_cache_config=global_kvcache_config,
                   max_num_tokens=100)

-        with pytest.raises(ValueError,
-                           match="should not exceed max_num_tokens"):
-            ids = [random.randint(10, 100) for _ in range(101)]
-            llm.generate([ids])
+        try:
+            with pytest.raises(RequestError,
+                               match="should not exceed max_num_tokens"):
+                ids = [random.randint(10, 100) for _ in range(101)]
+                llm.generate([ids])
+        finally:
+            llm.shutdown()


 @skip_ray
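
The try/finally added to both test_max_num_token_check variants ensures llm.shutdown() runs even if the expected RequestError never arrives. The same idea can be expressed once as a small context manager; the helper below is a sketch local to this page, not an API of the library.

# Sketch of a reusable cleanup helper, not part of this commit.
from contextlib import contextmanager


@contextmanager
def shutdown_on_exit(llm):
    """Yield the LLM and always release the engine, even if the body raises."""
    try:
        yield llm
    finally:
        llm.shutdown()

# Possible usage inside a test:
#     with shutdown_on_exit(LLM(model_path, max_num_tokens=100)) as llm:
#         with pytest.raises(RequestError, match="should not exceed max_num_tokens"):
#             llm.generate([ids])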
