3 files changed (+34, -4 lines)

The first file extends `_create_store` so that the sampler's `Store` also allocates per-slot `max_lengths_tensor` and `end_ids` tensors, one entry per sequence slot:

```diff
@@ -887,6 +887,8 @@ def _create_store(self) -> Store:
                 first_finish_reasons=int_tensor(
                     self.CACHE_INDIRECTION_SHAPE[:-1],
                 ),
+                max_lengths_tensor=int_tensor(self.max_num_sequences),
+                end_ids=int_tensor(self.max_num_sequences),
             )
         else:
             return self.Store(
```
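For orientation, here is a minimal sketch of what these two allocations amount to, assuming `int_tensor` is a small helper that builds an int32 tensor on the sampler's device (the helper body and its dtype/device defaults are assumptions, not taken from this diff):

```python
import torch

def int_tensor(shape, device: str = "cuda") -> torch.Tensor:
    # Assumed helper: zero-initialized int32 storage.
    return torch.zeros(shape, dtype=torch.int32, device=device)

max_num_sequences = 8  # one slot per in-flight sequence
max_lengths_tensor = int_tensor(max_num_sequences)  # per-slot length cap
end_ids = int_tensor(max_num_sequences)             # per-slot end-of-sequence token id
```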
The second hunk in the same file marks `setup_sampler_step` as an `@override` and renames its parameter to `scheduled_requests`; the loop that records each new request's length cap now reads from the renamed parameter:

```diff
@@ -1330,15 +1332,16 @@ def _is_new_request(self, request: LlmRequest) -> bool:
             or request.is_disagg_generation_transmission_complete
         )
 
-    def setup_sampler_step(self, requests: ScheduledRequests):
+    @override
+    def setup_sampler_step(self, scheduled_requests: ScheduledRequests):
         """Setup the sampler step for the requests
 
         Args:
             requests: list[LlmRequest]. The requests to setup the sampler step for
         """
         if self._use_beam_search:
-            self._prepare_beam_search(requests.all_requests())
-            for request in requests.all_requests():
+            self._prepare_beam_search(scheduled_requests.all_requests())
+            for request in scheduled_requests.all_requests():
                 if self._is_new_request(request):
                     self.store.max_lengths_tensor[request.py_seq_slot].fill_(
                         min(self.max_seq_len, request.orig_prompt_len + request.py_max_new_tokens)
```
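The value written into `max_lengths_tensor` is the per-request cap `min(max_seq_len, orig_prompt_len + py_max_new_tokens)`. A tiny worked example (the standalone function is hypothetical; the formula comes straight from the diff):

```python
def request_max_length(max_seq_len: int, orig_prompt_len: int, max_new_tokens: int) -> int:
    # A request may grow by at most max_new_tokens beyond its prompt,
    # but never past the sampler-wide sequence-length limit.
    return min(max_seq_len, orig_prompt_len + max_new_tokens)

assert request_max_length(2048, 100, 50) == 150     # capped by max_new_tokens
assert request_max_length(2048, 100, 3000) == 2048  # capped by max_seq_len
```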
The second file adjusts the MTPSampler's `Store`, a subclass of `TorchSampler.Store`, stubbing out the new fields alongside `finish_reasons`:

```diff
@@ -222,7 +222,10 @@ class Store(TorchSampler.Store):
         next_draft_tokens: torch.Tensor
         new_tokens_lens: torch.Tensor
         max_total_draft_tokens: torch.Tensor
-        finish_reasons: None = None  # Necessary to satisfy the interface of TorchSampler.Store
+        # Necessary to satisfy the interface of TorchSampler.Store
+        finish_reasons: None = None
+        end_ids: None = None
+        max_lengths_tensor: None = None
 
         def __post_init__(self):
             pass  # finish_reasons has no size to compare against new_tokens in MTPSampler
```
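Declaring the unused fields as `None` keeps the dataclass layout compatible with the parent without allocating tensors. A self-contained illustration of the pattern (class and field names here are invented for the example):

```python
from dataclasses import dataclass
from typing import Optional
import torch

@dataclass
class BaseStore:
    new_tokens: torch.Tensor
    finish_reasons: Optional[torch.Tensor] = None

@dataclass
class DraftStore(BaseStore):
    # Re-declaring the field as None documents that this subclass never
    # carries the tensor, while keeping the parent's attribute layout.
    finish_reasons: None = None

store = DraftStore(new_tokens=torch.zeros(4, dtype=torch.int32))
assert store.finish_reasons is None
```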
The third file, under `tests/unittest/_torch/sampler`, extends the test's `setup` to populate the new store fields and threads `seq_lens` into the `_write_finish_reasons` call:

```diff
@@ -691,16 +691,40 @@ def setup(requests: list["RequestCase"]):
         seq_slots = torch.tensor(
             [req.request.py_seq_slot for req in requests], device="cuda", dtype=torch.int64
         )
+        seq_lens = torch.tensor(
+            [req.request.max_beam_num_tokens for req in requests], dtype=torch.int32, device="cuda"
+        )
         new_tokens = torch.tensor(
             [req.new_tokens for req in requests], dtype=torch.int32, device="cuda"
         ).T
         sampler.store.new_tokens[:, seq_slots, BEAM] = new_tokens
+        max_seq_lens = torch.tensor(
+            [
+                min(
+                    sampler.max_seq_len, req.request.orig_prompt_len + req.request.py_max_new_tokens
+                )
+                for req in requests
+            ],
+            dtype=torch.int32,
+            device="cuda",
+        )
+        end_ids = torch.tensor(
+            [
+                req.request.py_end_id if req.request.py_end_id is not None else -1
+                for req in requests
+            ],
+            dtype=torch.int32,
+            device="cuda",
+        )
+        sampler.store.max_lengths_tensor[seq_slots] = max_seq_lens
+        sampler.store.end_ids[seq_slots] = end_ids
 
     def run():
         sampler._write_finish_reasons(
             [req.request for req in requests],
             finish_reasons=sampler.store.finish_reasons,
             new_tokens=sampler.store.new_tokens,
+            seq_lens=seq_lens,
             seq_slots=seq_slots,
         )
```
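For context, the logic under test plausibly compares each sequence's newest token against `end_ids` and its length against `max_lengths_tensor`. The sketch below illustrates that kind of check under simplified, assumed shapes; it is not the actual `_write_finish_reasons` implementation, and the reason codes are invented:

```python
import torch

# Assumed encodings (illustrative only): 0 = not finished, 1 = hit end id, 2 = hit length cap.
NOT_FINISHED, FINISHED_EOS, FINISHED_LENGTH = 0, 1, 2

def write_finish_reasons_sketch(
    finish_reasons: torch.Tensor,  # [num_slots] output buffer
    new_tokens: torch.Tensor,      # [num_slots] last sampled token per slot
    seq_lens: torch.Tensor,        # [num_slots] current sequence length per slot
    end_ids: torch.Tensor,         # [num_slots] end-of-sequence id, -1 if unset
    max_lengths: torch.Tensor,     # [num_slots] per-request length cap
) -> None:
    hit_eos = (new_tokens == end_ids) & (end_ids >= 0)
    hit_cap = seq_lens >= max_lengths
    finish_reasons.fill_(NOT_FINISHED)
    finish_reasons[hit_cap] = FINISHED_LENGTH
    finish_reasons[hit_eos] = FINISHED_EOS  # EOS takes precedence in this sketch

fr = torch.zeros(3, dtype=torch.int32)
write_finish_reasons_sketch(
    fr,
    new_tokens=torch.tensor([7, 2, 5]),
    seq_lens=torch.tensor([10, 4, 8]),
    end_ids=torch.tensor([7, -1, 5]),
    max_lengths=torch.tensor([10, 8, 16]),
)
# fr -> [1, 0, 1]: slot 0 hit its end id (EOS wins over the length cap here),
# slot 1 has no end id and is under its cap, slot 2 hit its end id.
```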