Skip to content

Commit 095b686

Browse files
authored
[TRTLLM-8650][fix] beam search request validation (#8433) (#9228)
Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com>
1 parent 8cd3b49 commit 095b686

File tree

4 files changed

+148
-27
lines changed

4 files changed

+148
-27
lines changed

tensorrt_llm/_torch/pyexecutor/executor_request_queue.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@ def _fetch_and_process_requests(
310310
new_requests)
311311

312312
# Validate and filter requests
313-
new_requests = self._validate_and_filter_requests(new_requests)
313+
new_requests = self._handle_special_queue_items(new_requests)
314314

315315
# Attach Python objects to requests
316316
if py_request_objects and (self.dist.tp_size > 1
@@ -482,11 +482,11 @@ def _handle_request_broadcasting(self,
482482

483483
return new_requests, py_request_objects
484484

485-
def _validate_and_filter_requests(
485+
def _handle_special_queue_items(
486486
self,
487487
new_requests: List[RequestQueueItem]) -> List[RequestQueueItem]:
488-
"""Validate and filter requests, handling shutdown signals."""
489-
valid_new_requests = []
488+
"""Handle special signals."""
489+
accepted_new_requests = []
490490
for idx, req_item in enumerate(new_requests):
491491
if req_item.is_shutdown_request:
492492
self.is_shutdown = True
@@ -499,17 +499,9 @@ def _validate_and_filter_requests(
499499
self.request_accumulated.extend(new_requests[idx + 1:])
500500
break
501501
else:
502-
valid_new_requests.append(req_item)
502+
accepted_new_requests.append(req_item)
503503

504-
# Check beam width validation
505-
for req_item in valid_new_requests:
506-
if req_item.request and hasattr(req_item.request,
507-
'sampling_config'):
508-
assert req_item.request.sampling_config.beam_width == self.max_beam_width, \
509-
f"Request beam width {req_item.request.sampling_config.beam_width} " \
510-
f"is not equal to max_beam_width {self.max_beam_width}. This is not supported!"
511-
512-
return valid_new_requests
504+
return accepted_new_requests
513505

514506
def _balance_requests_across_ranks(
515507
self, new_requests: List[RequestQueueItem],

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1607,6 +1607,16 @@ def _forward_step_inter_pp(self, scheduled_batch) -> SampleState:
16071607
)
16081608

16091609
def _validate_request(self, request: LlmRequest):
1610+
# Validate beam width
1611+
sampling_config = request.sampling_config
1612+
if sampling_config is not None:
1613+
if sampling_config.beam_width != self.max_beam_width:
1614+
raise ValueError(
1615+
f"Request beam width {sampling_config.beam_width} "
1616+
f"is not equal to max_beam_width {self.max_beam_width}. This is not supported!"
1617+
)
1618+
1619+
# Check token ID ranges
16101620
if isinstance(self.model_engine.model, DecoderModelForCausalLM):
16111621
# Only skip token‐range checks for Llama4 when the request has multimodal data
16121622
from ..models.modeling_llama import Llama4ForConditionalGeneration

tests/unittest/_torch/executor/test_executor_request_queue.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -475,8 +475,8 @@ def test_get_from_waiting_queue_edge_cases(executor_queue, queue_size,
475475
assert len(executor_queue.waiting_queue) == expected_remaining
476476

477477

478-
def test_validate_and_filter_requests(executor_queue):
479-
"""Test request validation and filtering."""
478+
def test_handle_special_queue_items(executor_queue):
479+
"""Test special queue item handling."""
480480
# Create a mock request without sampling_config to avoid beam validation
481481
mock_request = Mock()
482482
delattr(mock_request, 'sampling_config') if hasattr(
@@ -488,7 +488,7 @@ def test_validate_and_filter_requests(executor_queue):
488488

489489
requests = [normal_req, cancel_req, shutdown_req]
490490

491-
valid_requests = executor_queue._validate_and_filter_requests(requests)
491+
valid_requests = executor_queue._handle_special_queue_items(requests)
492492

493493
assert len(valid_requests) == 1
494494
assert valid_requests[0] == normal_req

tests/unittest/_torch/sampler/test_beam_search.py

Lines changed: 129 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
import pytest
1919
import torch
2020
from transformers.configuration_utils import PretrainedConfig
21+
from utils.llm_data import llm_models_root
22+
from utils.util import force_ampere
2123

2224
from tensorrt_llm import LLM, SamplingParams
2325
from tensorrt_llm._torch.attention_backend.interface import AttentionMetadata
@@ -31,6 +33,7 @@
3133
from tensorrt_llm._torch.models.modeling_utils import (
3234
ModelConfig, register_auto_model, register_checkpoint_weight_loader,
3335
register_config_loader)
36+
from tensorrt_llm.executor import RequestError
3437
from tensorrt_llm.executor.result import CompletionOutput, GenerationResult
3538
from tensorrt_llm.llmapi import CudaGraphConfig, KvCacheConfig
3639

@@ -263,11 +266,21 @@ def fixed_params():
263266

264267

265268
@pytest.fixture(scope="module")
266-
def llm(fixed_params, input_prompts):
269+
def model_kwargs(fixed_params) -> dict[str, Any]:
267270
assert fixed_params[
268271
"max_beam_width"] == 2, "This test only works for a beam width of 2"
269-
return LLM(
272+
return dict(
270273
model=_pl.Path("dummy_path"),
274+
checkpoint_loader=HfCheckpointLoader(
275+
weight_loader=DummyWeightLoader(),
276+
config_loader=DummyConfigLoader(),
277+
),
278+
)
279+
280+
281+
def _build_llm(fixed_params, input_prompts, model_kwargs):
282+
return LLM(
283+
**model_kwargs,
271284
kv_cache_config=KvCacheConfig(max_tokens=10000),
272285
max_batch_size=fixed_params["max_beam_width"] * len(
273286
input_prompts
@@ -276,16 +289,18 @@ def llm(fixed_params, input_prompts):
276289
max_beam_width=fixed_params["max_beam_width"],
277290
disable_overlap_scheduler=True,
278291
cuda_graph_config=None,
279-
checkpoint_loader=HfCheckpointLoader(weight_loader=DummyWeightLoader(),
280-
config_loader=DummyConfigLoader()))
292+
)
281293

282294

283295
@pytest.fixture(scope="module")
284-
def llm_cuda_graph(fixed_params, input_prompts):
285-
assert fixed_params[
286-
"max_beam_width"] == 2, "This test only works for a beam width of 2"
296+
def llm(fixed_params, input_prompts, model_kwargs):
297+
return _build_llm(fixed_params, input_prompts, model_kwargs)
298+
299+
300+
@pytest.fixture(scope="module")
301+
def llm_cuda_graph(fixed_params, input_prompts, model_kwargs):
287302
return LLM(
288-
model=_pl.Path("dummy_path"),
303+
**model_kwargs,
289304
kv_cache_config=KvCacheConfig(max_tokens=10000),
290305
max_batch_size=fixed_params["max_beam_width"] * len(
291306
input_prompts
@@ -295,8 +310,7 @@ def llm_cuda_graph(fixed_params, input_prompts):
295310
disable_overlap_scheduler=False,
296311
cuda_graph_config=CudaGraphConfig(batch_sizes=[1, 2, 4, 8],
297312
enable_padding=True),
298-
checkpoint_loader=HfCheckpointLoader(weight_loader=DummyWeightLoader(),
299-
config_loader=DummyConfigLoader()))
313+
)
300314

301315

302316
def check_generation_logits(beam: CompletionOutput,
@@ -473,5 +487,110 @@ def test_beam_search_output_shapes_cuda_graph_and_overlap(
473487
sampling_params)
474488

475489

490+
@force_ampere # Save H100 resource
491+
class TestParameterValidation:
492+
"""Ensure that unsupported request parameters do not crash/hang the engine."""
493+
494+
@pytest.fixture(scope="module")
495+
@staticmethod
496+
def fixed_params():
497+
return {"max_tokens": 8, "max_beam_width": 4}
498+
499+
@pytest.fixture(scope="module")
500+
@staticmethod
501+
def model_kwargs() -> dict[str, Any]:
502+
root = llm_models_root()
503+
assert root is not None
504+
return dict(model=root / "llama-models-v2" /
505+
"TinyLlama-1.1B-Chat-v1.0", )
506+
507+
# NB: Class-level fixture overrides do not work without this
508+
@pytest.fixture(scope="module")
509+
@staticmethod
510+
def llm(fixed_params, input_prompts, model_kwargs):
511+
return _build_llm(fixed_params, input_prompts, model_kwargs)
512+
513+
def _check_engine_responds(self, llm: LLM, input_prompts: list[str],
514+
fixed_params: dict):
515+
_ = llm.generate(input_prompts,
516+
sampling_params=SamplingParams(
517+
max_tokens=fixed_params["max_tokens"],
518+
n=1,
519+
best_of=fixed_params["max_beam_width"],
520+
use_beam_search=True,
521+
end_id=-1,
522+
))
523+
524+
@pytest.mark.timeout(120)
525+
@pytest.mark.threadleak(enabled=False)
526+
def test_use_beam_search_false(
527+
self,
528+
llm: LLM,
529+
input_prompts: list[str],
530+
fixed_params: dict,
531+
):
532+
assert fixed_params["max_beam_width"] > 2
533+
with pytest.raises(
534+
ValueError,
535+
match=
536+
".*Greedy decoding in the LLM API does not allow multiple returns.*"
537+
):
538+
_ = llm.generate(input_prompts,
539+
sampling_params=SamplingParams(
540+
max_tokens=fixed_params["max_tokens"],
541+
n=1,
542+
best_of=fixed_params["max_beam_width"],
543+
use_beam_search=False,
544+
end_id=-1,
545+
))
546+
self._check_engine_responds(llm, input_prompts, fixed_params)
547+
548+
@pytest.mark.timeout(120)
549+
@pytest.mark.threadleak(enabled=False)
550+
def test_use_beam_search_omitted(
551+
self,
552+
llm: LLM,
553+
input_prompts: list[str],
554+
fixed_params: dict,
555+
):
556+
assert fixed_params["max_beam_width"] > 2
557+
with pytest.raises(
558+
ValueError,
559+
match=
560+
".*Greedy decoding in the LLM API does not allow multiple returns.*"
561+
):
562+
_ = llm.generate(input_prompts,
563+
sampling_params=SamplingParams(
564+
max_tokens=fixed_params["max_tokens"],
565+
n=1,
566+
best_of=fixed_params["max_beam_width"],
567+
end_id=-1,
568+
))
569+
self._check_engine_responds(llm, input_prompts, fixed_params)
570+
571+
@pytest.mark.timeout(120)
572+
@pytest.mark.threadleak(enabled=False)
573+
def test_smaller_beam_width(
574+
self,
575+
llm: LLM,
576+
input_prompts: list[str],
577+
fixed_params: dict,
578+
):
579+
assert fixed_params["max_beam_width"] > 2
580+
with pytest.raises(
581+
RequestError,
582+
match=".*Request beam width 2 is not equal to max_beam_width 4.*"
583+
):
584+
_ = llm.generate(input_prompts,
585+
sampling_params=SamplingParams(
586+
max_tokens=fixed_params["max_tokens"],
587+
n=1,
588+
best_of=2,
589+
use_beam_search=True,
590+
end_id=-1,
591+
))
592+
self._check_engine_responds(llm, input_prompts, fixed_params)
593+
594+
476595
if __name__ == "__main__":
477596
pytest.main([__file__])

0 commit comments

Comments
 (0)