 import enum
 import sys
 from abc import ABC, abstractmethod
-from collections import defaultdict
+from collections import defaultdict, namedtuple
 from collections.abc import Iterable
 from concurrent import futures
 from dataclasses import dataclass
@@ -99,11 +99,14 @@ class LogProbsState:
 
 @dataclass(kw_only=True)
 class LogProbsStateList:
-    sampled_vals: list[list[list[float]]]
-    sampled_indices: list[list[list[int]]]
-    sampled_rank: list[list[list[int]]]
-    topk_vals: list[list[list[float]]]
-    topk_indices: list[list[list[int]]]
+    FloatState = list[list[list[float]]]
+    IntState = list[list[list[int]]]
+
+    sampled_vals: FloatState
+    sampled_indices: IntState
+    sampled_rank: IntState
+    topk_vals: FloatState
+    topk_indices: IntState
 
     @staticmethod
     def from_logprobs_state(logprobs_state: LogProbsState) -> "LogProbsStateList":
@@ -241,7 +244,7 @@ class SampleStateWithMMResult:
     data: MultimodalResult
 
 
-@dataclass(kw_only=True, frozen=True)
+@dataclass(kw_only=True, frozen=True, slots=True)
 class RequestGroupKey(Generic[GenericStrategyKeyType]):
     strategy_key: GenericStrategyKeyType
     needs_probs: bool
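A minimal sketch of why the frozen, slotted dataclass can now serve directly as the grouping key (illustrative GroupKey name, not the class above): frozen=True makes instances immutable and hashable, and slots=True drops the per-instance __dict__.

    from dataclasses import dataclass

    @dataclass(kw_only=True, frozen=True, slots=True)
    class GroupKey:
        strategy_key: str
        needs_probs: bool

    groups: dict[GroupKey, list[int]] = {}
    groups.setdefault(GroupKey(strategy_key="greedy", needs_probs=False), []).append(0)
    # Equal field values hash equal, so the lookup round-trips.
    assert groups[GroupKey(strategy_key="greedy", needs_probs=False)] == [0]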
@@ -420,10 +423,20 @@ def _group_requests_by_strategy_key(
420423 vocab_size : int ,
421424) -> dict [RequestGroupKey [GenericStrategyKeyType ], RequestGroupValue ]:
422425 # NB: Client code relies on request indices in returned torch.Tensor being sorted.
423- group_dict : dict [
424- tuple [GenericStrategyKeyType , bool ],
425- tuple [list [int ], list [Strategy ], list [int ], list [bool ], list [bool ]],
426- ] = defaultdict (lambda : ([], [], [], [], []))
426+ RequestGroupValueBuilder = namedtuple (
427+ "RequestGroupValueBuilder" ,
428+ [
429+ "indices" ,
430+ "strategies" ,
431+ "speculation_needs_probs_list" ,
432+ "need_processed_logprobs_list" ,
433+ "need_raw_logprobs_list" ,
434+ ],
435+ )
436+
437+ group_dict : dict [RequestGroupKey , RequestGroupValueBuilder ] = defaultdict (
438+ lambda : RequestGroupValueBuilder ([], [], [], [], [])
439+ )
427440
428441 for req_index , req in enumerate (requests ):
429442 strategy = _request_strategy (req , vocab_size = vocab_size )
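A self-contained sketch of the namedtuple-of-lists accumulator pattern introduced above (illustrative names, trimmed to two fields): each missing key gets a fresh builder whose fields are independent empty lists, and named attributes replace the positional entry[0] / entry[1] indexing.

    from collections import defaultdict, namedtuple

    Builder = namedtuple("Builder", ["indices", "strategies"])
    acc = defaultdict(lambda: Builder([], []))   # fresh lists for every missing key

    for i, strategy in enumerate(["greedy", "top_k", "greedy"]):
        entry = acc[strategy]
        entry.indices.append(i)           # was entry[0].append(...) in the tuple version
        entry.strategies.append(strategy)

    assert acc["greedy"].indices == [0, 2]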
@@ -438,37 +451,30 @@ def _group_requests_by_strategy_key(
         need_raw_logprobs = req.py_logprobs_mode == LogprobMode.RAW and req.return_log_probs
         needs_probs = speculation_needs_probs or need_processed_logprobs
         strategy_key = strategy_to_key(strategy, needs_probs)
-        group_dict_entry = group_dict[(strategy_key, needs_probs)]
-        group_dict_entry[0].append(req_index)
-        group_dict_entry[1].append(strategy)
+        group_dict_entry = group_dict[
+            RequestGroupKey(strategy_key=strategy_key, needs_probs=needs_probs)
+        ]
+        group_dict_entry.indices.append(req_index)
+        group_dict_entry.strategies.append(strategy)
         if speculation_needs_probs:
-            group_dict_entry[2].append(req_index)
-            group_dict_entry[3].append(need_processed_logprobs)
-            group_dict_entry[4].append(need_raw_logprobs)
+            group_dict_entry.speculation_needs_probs_list.append(req_index)
+            group_dict_entry.need_processed_logprobs_list.append(need_processed_logprobs)
+            group_dict_entry.need_raw_logprobs_list.append(need_raw_logprobs)
     return {
-        RequestGroupKey(
-            strategy_key=group_key[0],
-            needs_probs=group_key[1],
-        ): RequestGroupValue(
-            indices=torch.tensor(indices, pin_memory=pin_memory, dtype=torch.int32),
-            strategies=strategies,
+        group_key: RequestGroupValue(
+            indices=torch.tensor(group_value.indices, pin_memory=pin_memory, dtype=torch.int32),
+            strategies=group_value.strategies,
             speculation_needs_probs_indices=torch.tensor(
-                speculation_needs_probs_list, pin_memory=pin_memory, dtype=torch.int32
+                group_value.speculation_needs_probs_list, pin_memory=pin_memory, dtype=torch.int32
             ),
             need_processed_logprobs=torch.tensor(
-                need_processed_logprobs_list, pin_memory=pin_memory, dtype=torch.bool
+                group_value.need_processed_logprobs_list, pin_memory=pin_memory, dtype=torch.bool
             ),
             need_raw_logprobs=torch.tensor(
-                need_raw_logprobs_list, pin_memory=pin_memory, dtype=torch.bool
+                group_value.need_raw_logprobs_list, pin_memory=pin_memory, dtype=torch.bool
             ),
         )
-        for group_key, (
-            indices,
-            strategies,
-            speculation_needs_probs_list,
-            need_processed_logprobs_list,
-            need_raw_logprobs_list,
-        ) in group_dict.items()
+        for group_key, group_value in group_dict.items()
     }
 
 
@@ -967,8 +973,8 @@ def _create_store(self) -> Store:
         )
         sampled_log_prob_ranks = torch.empty(self.LOGPROBS_SHAPE, device="cuda", dtype=torch.int32)
         # These are 0 sized tensors, if topk-logprobs are not used
-        topk_indices = torch.empty(self.topk_logprobs_shape, device="cuda", dtype=torch.int32)
-        topk_vals = torch.empty(self.topk_logprobs_shape, device="cuda", dtype=torch.float32)
+        topk_indices = torch.empty(self.TOPK_LOGPROBS_SHAPE, device="cuda", dtype=torch.int32)
+        topk_vals = torch.empty(self.TOPK_LOGPROBS_SHAPE, device="cuda", dtype=torch.float32)
 
         # Only used for beam search
         cache_indirection: torch.Tensor | None = None
@@ -1034,7 +1040,7 @@ def __init__(self, args: Args):
             self.max_seq_len + (0 if args.disable_overlap_scheduler else 1),
         )
         self.LOGPROBS_SHAPE = (self.max_num_sequences, self.max_beam_width, self.max_tokens)
-        self.topk_logprobs_shape = (self.max_num_sequences, self.max_tokens, self.max_topk_logprobs)
+        self.TOPK_LOGPROBS_SHAPE = (self.max_num_sequences, self.max_tokens, self.max_topk_logprobs)
         # AutoDeploy build creates the sampler in inference mode,
         # which would disallow in-place mutating of new_tokens.
         # So, we temporarily exit inference mode.
@@ -2037,13 +2043,13 @@ def _prepare_log_probs(self, requests: list[LlmRequest]) -> None:
         )
         if self.max_topk_logprobs < self.batch_max_topk_logprobs:
             self.max_topk_logprobs = self.batch_max_topk_logprobs
-            self.topk_logprobs_shape = (
+            self.TOPK_LOGPROBS_SHAPE = (
                 self.max_num_sequences,
                 self.max_tokens,
                 self.max_topk_logprobs,
             )
-            self.store.topk_vals.resize_(self.topk_logprobs_shape)
-            self.store.topk_indices.resize_(self.topk_logprobs_shape)
+            self.store.topk_vals.resize_(self.TOPK_LOGPROBS_SHAPE)
+            self.store.topk_indices.resize_(self.TOPK_LOGPROBS_SHAPE)
 
     @override
     @torch.inference_mode()
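A hedged sketch of the buffer-growth step (hypothetical standalone tensors rather than the Store fields above): the top-k buffers start 0-sized and are regrown in place with resize_ once a larger top-k is requested; elements added by the resize are uninitialized until written.

    import torch

    topk_vals = torch.empty((4, 16, 0), dtype=torch.float32)  # 0-sized while no request wants top-k logprobs
    topk_vals.resize_((4, 16, 8))                              # grow in place when max_topk_logprobs increases
    assert topk_vals.shape == torch.Size([4, 16, 8])           # contents remain garbage until filled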