@@ -100,7 +100,7 @@ class LogProbsState:
100100@dataclass (kw_only = True )
101101class LogProbsStateList :
102102 FloatState = list [list [list [float ]]]
103- IntState = list [list [list [float ]]]
103+ IntState = list [list [list [int ]]]
104104
105105 sampled_vals : FloatState
106106 sampled_indices : IntState
@@ -389,6 +389,10 @@ def _get_max_beam_width(request: LlmRequest) -> int:
389389 return max_beam_width
390390
391391
392+ def _request_sampling_params_cachable (params : UtilsSamplingParams ) -> bool :
393+ return not params .use_beam_search
394+
395+
392396def _request_get_sampling_params (request : LlmRequest ) -> UtilsSamplingParams :
393397 sampling_config = request .sampling_config
394398 temperature = _unwrap_singleton (cast (Optional [list [float ]], sampling_config .temperature ))
@@ -409,10 +413,16 @@ def _request_get_sampling_params(request: LlmRequest) -> UtilsSamplingParams:
409413
410414
411415def _request_strategy (request : LlmRequest , * , vocab_size : int ) -> Strategy :
412- if not hasattr (request , "py_sampling_strategy" ) or _get_max_beam_width (request ) > 1 :
413- params = _request_get_sampling_params (request )
416+ # We try to cache the resolved strategy on the request object, as it's not cheap enough to
417+ # resolve it on every iteration.
418+ if hasattr (request , "py_sampling_strategy" ):
419+ return request .py_sampling_strategy
420+
421+ params = _request_get_sampling_params (request )
422+ sampling_strategy = resolve_sampling_strategy (params , vocab_size = vocab_size )
423+ if _request_sampling_params_cachable (params ):
414424 request .py_sampling_strategy = resolve_sampling_strategy (params , vocab_size = vocab_size )
415- return request . py_sampling_strategy
425+ return sampling_strategy
416426
417427
418428def _group_requests_by_strategy_key (
0 commit comments