
Commit 899fda9

[TRTLLM-9490][feat] use FlashInfer's top_k_sampling_from_probs (#9457)
Signed-off-by: ixlmar <[email protected]>
1 parent c5f52ab commit 899fda9

File tree: 2 files changed (+89, −20 lines)

tensorrt_llm/_torch/pyexecutor/sampling_utils_flashinfer.py

Lines changed: 58 additions & 20 deletions
@@ -383,22 +383,17 @@ def __init__(self, top_k: torch.Tensor, top_p: torch.Tensor, temperature: torch.
         def from_strategies(
             cls, strategies: list[Strategy], cuda_device: torch.device
         ) -> "_StrategyImpls.TopKTopPSampleOnly":
-            assert all(strat[0] in ["top_k_top_p", "top_k"] for strat in strategies)
-            narrowed_strats = cast(list[TopKTopP | TopK], strategies)
-            top_k_list = []
-            top_p_list = []
-            temperature_list = []
-            for strat in narrowed_strats:
-                top_k_list.append(strat[1])
-                if strat[0] == "top_k_top_p":
-                    top_p_list.append(strat[2])
-                    temperature_list.append(strat[3])
-                else:
-                    top_p_list.append(1.0)
-                    temperature_list.append(strat[2])
-            top_k = cls._make_tensor(top_k_list, torch.int32, cuda_device)
-            top_p = cls._make_tensor(top_p_list, torch.float32, cuda_device)
-            temperature = cls._make_tensor(temperature_list, torch.float32, cuda_device)
+            assert all(strat[0] == "top_k_top_p" for strat in strategies)
+            narrowed_strats = cast(list[TopKTopP], strategies)
+            top_k = cls._make_tensor(
+                [strat[1] for strat in narrowed_strats], torch.int32, cuda_device
+            )
+            top_p = cls._make_tensor(
+                [strat[2] for strat in narrowed_strats], torch.float32, cuda_device
+            )
+            temperature = cls._make_tensor(
+                [strat[3] for strat in narrowed_strats], torch.float32, cuda_device
+            )
             return cls(top_k, top_p, temperature)
 
         @override
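
For orientation: after this change, from_strategies assumes a homogeneous group in which every tuple has the shape ("top_k_top_p", top_k, top_p, temperature), and each parameter is batched into one tensor element per request. A minimal sketch of that batching step with hypothetical values, using plain torch.tensor in place of the class's _make_tensor helper (whose exact behavior is assumed here):

import torch

# Hypothetical batch of ("top_k_top_p", top_k, top_p, temperature) tuples,
# shaped like the strategy tuples matched in the diff above.
strategies = [
    ("top_k_top_p", 50, 0.90, 0.7),
    ("top_k_top_p", 20, 0.95, 1.0),
]

# One tensor per sampling parameter, one element per request.
top_k = torch.tensor([s[1] for s in strategies], dtype=torch.int32)
top_p = torch.tensor([s[2] for s in strategies], dtype=torch.float32)
temperature = torch.tensor([s[3] for s in strategies], dtype=torch.float32)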
@@ -427,6 +422,50 @@ def sample(
                 generator=generator,
             ), None
 
+    class TopKSampleOnly(StrategyImplSampleOnly):
+        def __init__(self, top_k: torch.Tensor, temperature: torch.Tensor):
+            self._top_k = top_k
+            self._temperature = temperature
+
+        @override
+        @classmethod
+        def from_strategies(
+            cls, strategies: list[Strategy], cuda_device: torch.device
+        ) -> "_StrategyImpls.TopKSampleOnly":
+            assert all(strat[0] == "top_k" for strat in strategies)
+            narrowed_strats = cast(list[TopK], strategies)
+            top_k = cls._make_tensor(
+                [strat[1] for strat in narrowed_strats], torch.int32, cuda_device
+            )
+            temperature = cls._make_tensor(
+                [strat[2] for strat in narrowed_strats], torch.float32, cuda_device
+            )
+            return cls(top_k, temperature)
+
+        @override
+        def sample(
+            self,
+            logits: torch.Tensor,
+            *,
+            group_logit_indices: Optional[torch.Tensor] = None,
+            generator: Optional[torch.Generator] = None,
+        ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+            probs = self._prepare_probs_with_temperature(
+                logits, group_logit_indices, self._temperature
+            )
+            return flashinfer.sampling.top_k_sampling_from_probs(
+                probs,
+                top_k=self._top_k,
+                # NB: Leveraging 'indices' would require applying temperature+softmax
+                # before batching, because 'flashinfer.sampling.softmax' has no 'indices'
+                # argument; but that would also compute the softmax unnecessarily in
+                # situations allowing flashinfer.sampling...._sampling_from_logits.
+                # indices=group_logit_indices,
+                deterministic=True,
+                check_nan=self._flashinfer_check_nans(probs),
+                generator=generator,
+            ), None
+
     class TopPSampleOnly(StrategyImplSampleOnly):
         def __init__(self, top_p: torch.Tensor, temperature: torch.Tensor):
             self._top_p = top_p
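
The new TopKSampleOnly.sample boils down to temperature scaling, a softmax, and a single FlashInfer kernel call. A standalone sketch of that flow, assuming a CUDA device and a flashinfer build that exposes sampling.top_k_sampling_from_probs with the keyword arguments used above; torch.softmax stands in for the module's _prepare_probs_with_temperature helper, and all shapes and values are illustrative:

import torch
import flashinfer

batch_size, vocab_size = 2, 32
logits = torch.randn(batch_size, vocab_size, device="cuda")
temperature = torch.tensor([[0.7], [1.0]], device="cuda")
top_k = torch.tensor([10, 5], dtype=torch.int32, device="cuda")

# Temperature scaling + softmax (roughly what _prepare_probs_with_temperature
# presumably produces for the grouped rows).
probs = torch.softmax(logits / temperature, dim=-1)

# One sampled token id per row, each row truncated to its own top-k.
tokens = flashinfer.sampling.top_k_sampling_from_probs(
    probs,
    top_k=top_k,
    deterministic=True,
    check_nan=False,  # the NaN check would force a device sync
    generator=None,
)
print(tokens.shape)  # torch.Size([2])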
@@ -540,10 +579,9 @@ def strategy_grouping_key(strategy: Strategy, return_probs: bool) -> STRATEGY_KE
     match strategy:
         case ("top_p", _, _):
             return _StrategyImpls.TopPSampleOnly
-        case ("top_k_top_p", _, _, _) | ("top_k", _, _):
-            # NB: There is no TopKSampleOnly, because FlashInfer only provides
-            # top_k_sampling_from_probs (not top_k_sampling_from_logits),
-            # which is likely slower than top_k_top_p_sampling_from_logits.
+        case ("top_k", _, _):
+            return _StrategyImpls.TopKSampleOnly
+        case ("top_k_top_p", _, _, _):
             return _StrategyImpls.TopKTopPSampleOnly
         case ("temperature", _):
             return _StrategyImpls.TemperatureOnlySampleOnly
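
With the dedicated branch in place, every strategy tuple now maps directly to one implementation class, so requests with the same strategy shape can be grouped and sampled in a single batched kernel call. A toy mirror of this tuple dispatch, with strings standing in for the real _StrategyImpls classes:

# Toy version of strategy_grouping_key; the string values are placeholders
# for the real _StrategyImpls.* classes.
def grouping_key(strategy: tuple) -> str:
    match strategy:
        case ("top_p", _, _):
            return "TopPSampleOnly"
        case ("top_k", _, _):
            return "TopKSampleOnly"  # the new dedicated top-k group
        case ("top_k_top_p", _, _, _):
            return "TopKTopPSampleOnly"
        case ("temperature", _):
            return "TemperatureOnlySampleOnly"
        case _:
            raise ValueError(f"unknown strategy: {strategy}")

assert grouping_key(("top_k", 50, 0.7)) == "TopKSampleOnly"
assert grouping_key(("top_k_top_p", 50, 0.9, 0.7)) == "TopKTopPSampleOnly"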

tests/unittest/_torch/sampler/test_torch_sampler.py

Lines changed: 31 additions & 0 deletions
@@ -1601,6 +1601,37 @@ def _mock_flashinfer_from_logits(
 
         patch_ctx.setattr(flashinfer.sampling, "sampling_from_logits", _mock_flashinfer_from_logits)
 
+        def _mock_flashinfer_top_k(
+            probs: torch.Tensor,
+            *,
+            top_k: torch.Tensor,
+            deterministic: bool,
+            check_nan: bool,
+            generator: torch.Generator,
+        ) -> torch.Tensor:
+            assert deterministic
+            assert not check_nan, "check_nan syncs"
+            assert generator is sampler.get_generator(probs.device)
+            nonlocal mock_sampling_log
+            new_entries = [
+                TestBatchedSampling._MockSamplingLogEntry(
+                    probs=probs[row_idx],
+                    sampling_params=TestBatchedSampling._TorchUtilsSamplingParams(
+                        top_k=top_k[row_idx],
+                        top_p=None,
+                        temperature=None,
+                    ),
+                )
+                for row_idx in range(probs.size(0))
+            ]
+            mock_tokens = torch.arange(
+                len(mock_sampling_log), len(mock_sampling_log) + len(new_entries)
+            )
+            mock_sampling_log += new_entries
+            return mock_tokens
+
+        patch_ctx.setattr(flashinfer.sampling, "top_k_sampling_from_probs", _mock_flashinfer_top_k)
+
         def _mock_flashinfer_top_p(
             probs: torch.Tensor,
             *,
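
The mock follows the usual pytest MonkeyPatch pattern: swap the kernel attribute on the flashinfer.sampling module, record what the sampler passed in, and return predictable token ids. A reduced, self-contained version of that pattern (the real test's patch_ctx and log-entry types are more elaborate, and flashinfer must be importable):

import pytest
import torch
import flashinfer

def test_top_k_kernel_is_intercepted(monkeypatch: pytest.MonkeyPatch):
    calls = []

    def _fake_top_k(probs, *, top_k, deterministic, check_nan, generator):
        # Record the arguments instead of actually sampling.
        calls.append((tuple(probs.shape), top_k.tolist()))
        return torch.zeros(probs.size(0), dtype=torch.long)

    monkeypatch.setattr(flashinfer.sampling, "top_k_sampling_from_probs", _fake_top_k)

    probs = torch.rand(2, 8)
    tokens = flashinfer.sampling.top_k_sampling_from_probs(
        probs,
        top_k=torch.tensor([3, 5], dtype=torch.int32),
        deterministic=True,
        check_nan=False,
        generator=None,
    )
    assert tokens.tolist() == [0, 0]
    assert calls == [((2, 8), [3, 5])]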
