Skip to content

Commit cabd43a

Browse files
Address reviews 2
Signed-off-by: Yuan Tong <[email protected]>
1 parent dda46a5 commit cabd43a

File tree

1 file changed

+14
-4
lines changed

1 file changed

+14
-4
lines changed

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ class LogProbsState:
100100
@dataclass(kw_only=True)
101101
class LogProbsStateList:
102102
FloatState = list[list[list[float]]]
103-
IntState = list[list[list[float]]]
103+
IntState = list[list[list[int]]]
104104

105105
sampled_vals: FloatState
106106
sampled_indices: IntState
@@ -389,6 +389,10 @@ def _get_max_beam_width(request: LlmRequest) -> int:
389389
return max_beam_width
390390

391391

392+
def _request_sampling_params_cachable(params: UtilsSamplingParams) -> bool:
393+
return not params.use_beam_search
394+
395+
392396
def _request_get_sampling_params(request: LlmRequest) -> UtilsSamplingParams:
393397
sampling_config = request.sampling_config
394398
temperature = _unwrap_singleton(cast(Optional[list[float]], sampling_config.temperature))
@@ -409,10 +413,16 @@ def _request_get_sampling_params(request: LlmRequest) -> UtilsSamplingParams:
409413

410414

411415
def _request_strategy(request: LlmRequest, *, vocab_size: int) -> Strategy:
412-
if not hasattr(request, "py_sampling_strategy") or _get_max_beam_width(request) > 1:
413-
params = _request_get_sampling_params(request)
416+
# We try to cache the resolved strategy on the request object, as it's not cheap enough to
417+
# resolve it on every iteration.
418+
if hasattr(request, "py_sampling_strategy"):
419+
return request.py_sampling_strategy
420+
421+
params = _request_get_sampling_params(request)
422+
sampling_strategy = resolve_sampling_strategy(params, vocab_size=vocab_size)
423+
if _request_sampling_params_cachable(params):
414424
request.py_sampling_strategy = resolve_sampling_strategy(params, vocab_size=vocab_size)
415-
return request.py_sampling_strategy
425+
return sampling_strategy
416426

417427

418428
def _group_requests_by_strategy_key(

0 commit comments

Comments
 (0)