@@ -383,22 +383,17 @@ def __init__(self, top_k: torch.Tensor, top_p: torch.Tensor, temperature: torch.
def from_strategies(
    cls, strategies: list[Strategy], cuda_device: torch.device
) -> "_StrategyImpls.TopKTopPSampleOnly":
    """Build an instance from a list of ``"top_k_top_p"`` strategies.

    Every strategy must carry the tag ``"top_k_top_p"`` at index 0. The
    values at indices 1, 2 and 3 are gathered across all strategies into
    the ``top_k`` (int32), ``top_p`` (float32) and ``temperature``
    (float32) tensors respectively, each created via ``cls._make_tensor``
    for ``cuda_device``.
    """
    assert all(entry[0] == "top_k_top_p" for entry in strategies)
    # The assertion above justifies narrowing the element type for the type checker.
    typed_entries = cast(list[TopKTopP], strategies)

    def gather(position: int, dtype: torch.dtype) -> torch.Tensor:
        # Collect one field of every strategy into a single tensor.
        values = [entry[position] for entry in typed_entries]
        return cls._make_tensor(values, dtype, cuda_device)

    # Tensors are built in the order top_k, top_p, temperature.
    k_tensor = gather(1, torch.int32)
    p_tensor = gather(2, torch.float32)
    temperature_tensor = gather(3, torch.float32)
    return cls(k_tensor, p_tensor, temperature_tensor)
403398
404399 @override
@@ -427,6 +422,50 @@ def sample(
427422 generator = generator ,
428423 ), None
429424
class TopKSampleOnly(StrategyImplSampleOnly):
    """Sample-only strategy implementation for groups of ``"top_k"`` strategies.

    Stores one top-k value and one temperature per strategy as tensors and
    samples through ``flashinfer.sampling.top_k_sampling_from_probs``.
    """

    def __init__(self, top_k: torch.Tensor, temperature: torch.Tensor):
        # Both tensors are expected to hold one entry per strategy, in the
        # order produced by from_strategies.
        self._temperature = temperature
        self._top_k = top_k

    @override
    @classmethod
    def from_strategies(
        cls, strategies: list[Strategy], cuda_device: torch.device
    ) -> "_StrategyImpls.TopKSampleOnly":
        """Build an instance from a list of ``"top_k"`` strategies.

        Every strategy must carry the tag ``"top_k"`` at index 0. Index 1 of
        each strategy is gathered into the ``top_k`` tensor (int32) and
        index 2 into the ``temperature`` tensor (float32), both created via
        ``cls._make_tensor`` for ``cuda_device``.
        """
        assert all(entry[0] == "top_k" for entry in strategies)
        # The assertion above justifies narrowing the element type for the type checker.
        typed_entries = cast(list[TopK], strategies)

        def gather(position: int, dtype: torch.dtype) -> torch.Tensor:
            # Collect one field of every strategy into a single tensor.
            values = [entry[position] for entry in typed_entries]
            return cls._make_tensor(values, dtype, cuda_device)

        # Tensors are built in the order top_k, temperature.
        k_tensor = gather(1, torch.int32)
        temperature_tensor = gather(2, torch.float32)
        return cls(k_tensor, temperature_tensor)

    @override
    def sample(
        self,
        logits: torch.Tensor,
        *,
        group_logit_indices: Optional[torch.Tensor] = None,
        generator: Optional[torch.Generator] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Sample using FlashInfer's top-k sampling on probabilities.

        The input for the kernel is obtained from
        ``self._prepare_probs_with_temperature`` (called with ``logits``,
        ``group_logit_indices`` and the stored temperature).

        Returns:
            A pair of the FlashInfer sampling result and ``None``; the second
            slot is never populated by this implementation.
        """
        probs = self._prepare_probs_with_temperature(
            logits, group_logit_indices, self._temperature
        )
        # NB: 'indices' (i.e. indices=group_logit_indices) is deliberately not
        #     passed to the kernel. Leveraging it would require applying
        #     temperature+softmax before batching, because
        #     'flashinfer.sampling.softmax' has no 'indices' argument; that
        #     would needlessly compute softmax even in situations that allow
        #     flashinfer.sampling...._sampling_from_logits.
        sampled = flashinfer.sampling.top_k_sampling_from_probs(
            probs,
            top_k=self._top_k,
            deterministic=True,
            check_nan=self._flashinfer_check_nans(probs),
            generator=generator,
        )
        return sampled, None
468+
430469 class TopPSampleOnly (StrategyImplSampleOnly ):
431470 def __init__ (self , top_p : torch .Tensor , temperature : torch .Tensor ):
432471 self ._top_p = top_p
@@ -540,10 +579,9 @@ def strategy_grouping_key(strategy: Strategy, return_probs: bool) -> STRATEGY_KE
540579 match strategy :
541580 case ("top_p" , _, _):
542581 return _StrategyImpls .TopPSampleOnly
543- case ("top_k_top_p" , _, _, _) | ("top_k" , _, _):
544- # NB: There is no TopKSampleOnly, because FlashInfer only provides
545- # top_k_sampling_from_probs (not top_k_sampling_from_logits),
546- # which is likely slower than top_k_top_p_sampling_from_logits.
582+ case ("top_k" , _, _):
583+ return _StrategyImpls .TopKSampleOnly
584+ case ("top_k_top_p" , _, _, _):
547585 return _StrategyImpls .TopKTopPSampleOnly
548586 case ("temperature" , _):
549587 return _StrategyImpls .TemperatureOnlySampleOnly
0 commit comments