@@ -473,10 +473,12 @@ class SampleStateTensorsHostTRTLLM(SampleStateTensors):
     finish_reasons: torch.Tensor
     sequence_lengths: torch.Tensor
     cum_log_probs: torch.Tensor | None = None
+    gathered_ids: torch.Tensor | None = None


 @dataclass(kw_only=True)
 class SampleStateTRTLLM(SampleState):
+    finalize_events: dict[str, CudaEvent]
     host: SampleStateTensorsHostTRTLLM


@@ -672,6 +674,24 @@ def sample_async(self, scheduled_requests: ScheduledRequests,
             self.store["decoder_state"],
             self.store["decoding_input"][self.micro_batch_idx])

+        finalize_events = {}
+        gathered_ids = None
+        if beam_width > 1:
+            finished_sum_device = self.store["decoder_state"].finished_sum
+
+            for request in scheduled_requests.all_requests():
+                if request.is_context_init_state:
+                    continue
+                if finished_sum_device[request.seq_slot] == beam_width:
+                    finalize_events[
+                        request.request_id] = self._finalize_request(
+                            request, False)
+                elif request.streaming:
+                    finalize_events[
+                        request.request_id] = self._finalize_request(
+                            request, True)
+            gathered_ids = self.store["decoder_state"].gathered_ids.to(
+                'cpu', non_blocking=True)
         new_output_tokens = self.store["decoder_state"].all_new_tokens.to(
             'cpu', non_blocking=True)
         finished_sum = self.store["decoder_state"].finished_sum.to(
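(For context on the pattern the added block relies on: device-side work and a non-blocking device-to-host copy are queued on the current CUDA stream, an event is recorded behind them, and the consumer later synchronizes on that event once before reading the host tensor. A minimal, self-contained sketch of that pattern follows; the helper names and the `sort` stand-in for the finalize/gather kernels are illustrative only, not part of this change.)

```python
import torch


def launch_and_copy(step_output: torch.Tensor):
    """Enqueue device work plus a D2H copy, and record an event behind both."""
    # Stand-in for the per-request finalize/gather kernels launched above.
    gathered = step_output.sort(dim=-1).values
    # The copy is queued on the current stream; it has not necessarily finished yet.
    host_copy = gathered.to('cpu', non_blocking=True)
    done = torch.cuda.Event()
    done.record()  # marks the point after the gather and the copy were enqueued
    return host_copy, done


def read_result(host_copy: torch.Tensor, done: torch.cuda.Event) -> list:
    """Block once on the event; afterwards the host tensor is safe to read."""
    done.synchronize()
    return host_copy.tolist()


if torch.cuda.is_available():
    out = torch.randn(2, 4, device='cuda')
    host, event = launch_and_copy(out)
    print(read_result(host, event))
```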
@@ -698,7 +718,8 @@ def sample_async(self, scheduled_requests: ScheduledRequests,
                                             finish_reasons=finish_reasons,
                                             sequence_lengths=sequence_lengths,
                                             log_probs=log_probs,
-                                            cum_log_probs=cum_log_probs)
+                                            cum_log_probs=cum_log_probs,
+                                            gathered_ids=gathered_ids)

         sampler_event = torch.cuda.Event()
         sampler_event.record()
@@ -709,7 +730,8 @@ def sample_async(self, scheduled_requests: ScheduledRequests,
         return SampleStateTRTLLM(scheduled_requests=scheduled_requests,
                                  device=device,
                                  host=host,
-                                 sampler_event=sampler_event)
+                                 sampler_event=sampler_event,
+                                 finalize_events=finalize_events)

     @torch.inference_mode()
     def update_requests(self, state: SampleStateTRTLLM):
@@ -797,7 +819,7 @@ def update_requests_multiple_beams_or_drafting(self,
         ) if state.host.cum_log_probs is not None else None
         log_probs_host = state.host.log_probs.tolist(
         ) if state.host.log_probs is not None else None
-        finalize_events = {}
+        finalize_events = state.finalize_events

         reqs = [
             r for r in state.scheduled_requests.context_requests
@@ -865,19 +887,9 @@ def update_requests_multiple_beams_or_drafting(self,

             if finished_sum_host[seq_slot] == beam_width:
                 request.state = LlmRequestState.GENERATION_COMPLETE
-                if beam_width > 1:
-                    finalize_events[
-                        request.request_id] = self._finalize_request(
-                            request, False)
-            elif request.streaming and beam_width > 1:
-                finalize_events[request.request_id] = self._finalize_request(
-                    request, True)
-        # post process all requests if necessary
-        if beam_width > 1:
-            for request in reqs:
-                if request.request_id in finalize_events:
-                    self._post_process_request(
-                        request, finalize_events[request.request_id])
+        for request in reqs:
+            if request.request_id in finalize_events:
+                self._post_process_request(request, state)

     def _finalize_request(self, request: LlmRequest, streaming: bool):
         """ Finalizes the request. This is necessary for beam search. """
@@ -888,25 +900,24 @@ def _finalize_request(self, request: LlmRequest, streaming: bool):
         return event

     def _post_process_request(self, request: LlmRequest,
-                              finalize_event: CudaEvent):
+                              state: SampleStateTRTLLM):
         """ Post Process the request. Updates the sequence according to the beam search results.
             request: LlmRequest which shall be post processed
             finalize_event: CudaEvent to wait for the finalize step to finish
         """
         seq_slot = request.py_seq_slot
         beam_width = request.sampling_config.beam_width
         # synchronize on the finalize event before continuing the post processing.
-        finalize_event.synchronize()
+        # should be unnecessary, as update_requests already waits for the sampler event
+        state.finalize_events[request.request_id].synchronize()

         # Get these values again, as they might have changed during the finalize step
-        output_ids_host = self.store["decoder_state"].gathered_ids.to('cpu')
-        sequence_lengths_host = self.store["decoder_state"].sequence_lengths.to(
-            'cpu')
+        output_ids_host = state.host.gathered_ids
+        sequence_lengths_host = state.host.sequence_lengths

         if request.py_return_log_probs:
-            log_probs_host = self.store["decoder_state"].log_probs.to('cpu')
-            cum_log_probs_host = self.store["decoder_state"].cum_log_probs.to(
-                'cpu')
+            log_probs_host = state.host.log_probs
+            cum_log_probs_host = state.host.cum_log_probs

         generated_tokens = [[0]] * beam_width
         log_probs = [[] for _ in range(beam_width)]
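(Taken together, the hunks change `_post_process_request` from issuing fresh blocking `.to('cpu')` copies per request to reading host tensors and CUDA events that `sample_async` already produced and stored on the sample state. Below is a condensed sketch of that producer/consumer hand-off; everything other than the field names taken from the diff (`host`, `gathered_ids`, `sampler_event`, `finalize_events`) is made up for illustration.)

```python
from dataclasses import dataclass, field

import torch


@dataclass(kw_only=True)
class HostTensors:
    sequence_lengths: torch.Tensor
    gathered_ids: torch.Tensor | None = None


@dataclass(kw_only=True)
class State:
    host: HostTensors
    sampler_event: torch.cuda.Event
    finalize_events: dict[str, torch.cuda.Event] = field(default_factory=dict)


def produce(seq_lens: torch.Tensor, gathered: torch.Tensor,
            finished_ids: list[str]) -> State:
    """Producer (like sample_async): queue copies, record per-request events."""
    finalize_events = {}
    for rid in finished_ids:
        ev = torch.cuda.Event()
        ev.record()  # in the real code this follows the finalize step for `rid`
        finalize_events[rid] = ev
    host = HostTensors(sequence_lengths=seq_lens.to('cpu', non_blocking=True),
                       gathered_ids=gathered.to('cpu', non_blocking=True))
    sampler_event = torch.cuda.Event()
    sampler_event.record()
    return State(host=host, sampler_event=sampler_event,
                 finalize_events=finalize_events)


def consume(state: State, request_id: str) -> list:
    """Consumer (like _post_process_request): wait, then read host tensors only."""
    state.sampler_event.synchronize()
    if request_id in state.finalize_events:
        # Redundant after the sampler-event sync above, mirroring the diff's comment.
        state.finalize_events[request_id].synchronize()
    return state.host.gathered_ids.tolist()
```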