File tree (3 files changed, +32 −4 lines changed):
tests/unittest/_torch/sampler

@@ -1541,15 +1541,16 @@ def _is_new_request(self, request: LlmRequest) -> bool:
             or request.is_disagg_generation_transmission_complete
         )
 
-    def setup_sampler_step(self, requests: ScheduledRequests):
+    @override
+    def setup_sampler_step(self, scheduled_requests: ScheduledRequests):
         """Setup the sampler step for the requests
 
         Args:
             requests: list[LlmRequest]. The requests to setup the sampler step for
         """
         if self._use_beam_search:
-            self._prepare_beam_search(requests.all_requests())
-        for request in requests.all_requests():
+            self._prepare_beam_search(scheduled_requests.all_requests())
+        for request in scheduled_requests.all_requests():
             if self._is_new_request(request):
                 self.store.max_lengths_tensor[request.py_seq_slot].fill_(
                     min(self.max_seq_len, request.orig_prompt_len + request.py_max_new_tokens)
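This hunk renames the parameter to `scheduled_requests`, marks the method with `@override` (assuming an `override` import already exists in the module), and, for requests that are new to the sampler, records a per-slot length cap in the store. Below is a minimal standalone sketch of that per-slot setup; `SimpleStore` and the bare `request` object are illustrative stand-ins, and only the `min(max_seq_len, orig_prompt_len + py_max_new_tokens)` rule and the slot-indexed write mirror the diff above.

import torch

class SimpleStore:
    """Illustrative stand-in for the sampler store touched by the diff."""
    def __init__(self, max_slots: int) -> None:
        self.max_lengths_tensor = torch.zeros(max_slots, dtype=torch.int32)
        self.end_ids = torch.full((max_slots,), -1, dtype=torch.int32)

def setup_new_request(store: SimpleStore, max_seq_len: int, request) -> None:
    # Cap this slot's sequence length at the sampler-wide maximum.
    store.max_lengths_tensor[request.py_seq_slot].fill_(
        min(max_seq_len, request.orig_prompt_len + request.py_max_new_tokens)
    )
    # Record the request's end id, using -1 when none is set (as the test
    # setup further down does); whether the real sampler writes end_ids in
    # this exact method is an assumption of this sketch.
    store.end_ids[request.py_seq_slot] = (
        request.py_end_id if request.py_end_id is not None else -1
    )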
@@ -222,7 +222,10 @@ class Store(TorchSampler.Store):
         next_draft_tokens: torch.Tensor
         new_tokens_lens: torch.Tensor
         max_total_draft_tokens: torch.Tensor
-        finish_reasons: None = None  # Necessary to satisfy the interface of TorchSampler.Store
+        # Necessary to satisfy the interface of TorchSampler.Store
+        finish_reasons: None = None
+        end_ids: None = None
+        max_lengths_tensor: None = None
 
         def __post_init__(self):
             pass  # finish_reasons has no size to compare against new_tokens in MTPSampler
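The MTP-side Store gains the same interface fields but pins them to None, since this sampler never allocates them, and `__post_init__` stays a no-op so no size check runs against `new_tokens`. A minimal standalone sketch of that pattern, assuming simplified base/derived classes rather than the real TorchSampler.Store hierarchy:

from dataclasses import dataclass
import torch

@dataclass
class BaseStore:
    # Field every sampler provides.
    new_tokens: torch.Tensor

@dataclass
class DraftStore(BaseStore):
    next_draft_tokens: torch.Tensor
    new_tokens_lens: torch.Tensor
    # Fields the shared code may read but this sampler never fills:
    # declaring them as `None = None` satisfies attribute lookups without
    # allocating tensors.
    finish_reasons: None = None
    end_ids: None = None
    max_lengths_tensor: None = None

    def __post_init__(self):
        pass  # nothing to validate: the None fields have no shape to check

# Usage: only the tensor-backed fields need to be supplied.
store = DraftStore(
    new_tokens=torch.zeros(4, dtype=torch.int32),
    next_draft_tokens=torch.zeros(4, dtype=torch.int32),
    new_tokens_lens=torch.zeros(4, dtype=torch.int32),
)
assert store.finish_reasons is None and store.end_ids is None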
@@ -701,16 +701,40 @@ def setup(requests: list["RequestCase"]):
         seq_slots = torch.tensor(
             [req.request.py_seq_slot for req in requests], device="cuda", dtype=torch.int64
         )
+        seq_lens = torch.tensor(
+            [req.request.max_beam_num_tokens for req in requests], dtype=torch.int32, device="cuda"
+        )
         new_tokens = torch.tensor(
             [req.new_tokens for req in requests], dtype=torch.int32, device="cuda"
         ).T
         sampler.store.new_tokens[:, seq_slots, BEAM] = new_tokens
+        max_seq_lens = torch.tensor(
+            [
+                min(
+                    sampler.max_seq_len, req.request.orig_prompt_len + req.request.py_max_new_tokens
+                )
+                for req in requests
+            ],
+            dtype=torch.int32,
+            device="cuda",
+        )
+        end_ids = torch.tensor(
+            [
+                req.request.py_end_id if req.request.py_end_id is not None else -1
+                for req in requests
+            ],
+            dtype=torch.int32,
+            device="cuda",
+        )
+        sampler.store.max_lengths_tensor[seq_slots] = max_seq_lens
+        sampler.store.end_ids[seq_slots] = end_ids
 
     def run():
         sampler._write_finish_reasons(
             [req.request for req in requests],
             finish_reasons=sampler.store.finish_reasons,
             new_tokens=sampler.store.new_tokens,
+            seq_lens=seq_lens,
             seq_slots=seq_slots,
         )
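The test setup now pre-populates the state the sampler would have written: per-request sequence lengths (from `max_beam_num_tokens`), per-slot length caps, and end ids (-1 when unset), and `run()` forwards `seq_lens` into `_write_finish_reasons`. As a rough illustration of how finish reasons can be derived from exactly these inputs, here is a toy, slot-free version; the reason codes, tie-breaking order, and any beam handling are assumptions of the sketch, not the library's kernel:

import torch

# Assumed reason codes for the sketch only.
NOT_FINISHED, FINISHED_END_ID, FINISHED_LENGTH = 0, 1, 2

def toy_finish_reasons(
    new_tokens: torch.Tensor,   # [num_requests] last sampled token per request
    seq_lens: torch.Tensor,     # [num_requests] current sequence length
    max_lengths: torch.Tensor,  # [num_requests] per-request length cap
    end_ids: torch.Tensor,      # [num_requests] end id, -1 meaning "unset"
) -> torch.Tensor:
    reasons = torch.full_like(seq_lens, NOT_FINISHED)
    # Hitting the per-request cap finishes the sequence on length...
    reasons[seq_lens >= max_lengths] = FINISHED_LENGTH
    # ...and sampling a real end id finishes it on end id (checked last here,
    # so it wins ties; the real sampler's precedence may differ).
    reasons[(new_tokens == end_ids) & (end_ids != -1)] = FINISHED_END_ID
    return reasons

# Example with two requests: one capped by length, one stopping on its end id.
print(toy_finish_reasons(
    new_tokens=torch.tensor([7, 2]),
    seq_lens=torch.tensor([16, 5]),
    max_lengths=torch.tensor([16, 32]),
    end_ids=torch.tensor([-1, 2]),
))  # tensor([2, 1])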