@@ -3088,13 +3088,8 @@ def update_requests(
     @nvtx_range("update_requests_single_beam_single_step")
     def update_requests_single_beam_single_step(self, state: SampleStateTRTLLM):
         """Specialization of update_requests for single beam and single step"""
-        new_tokens_host = state.host.new_tokens.flatten().tolist()
         sequence_lengths_host_data = state.host.sequence_lengths.flatten().tolist()
         finish_reasons = state.host.finish_reasons.flatten().tolist()
-        log_probs_host_tensor = state.host.log_probs
-        cum_log_probs_host = (
-            state.host.cum_log_probs.tolist() if state.host.cum_log_probs is not None else None
-        )

         reqs = [
             r for r in state.scheduled_requests.context_requests if not r.is_context_init_state
@@ -3104,44 +3099,53 @@ def update_requests_single_beam_single_step(self, state: SampleStateTRTLLM):
             if not r.is_generation_complete_state
         ]

-        reqs_with_new_tokens = [
-            r for r in reqs if (sequence_lengths_host_data[r.py_seq_slot] > r.get_num_tokens(0))
-        ]
+        # NB: To ensure good performance, we must
+        # 1. Avoid accessing torch.Tensor objects inside the for-each-request loops
+        # 2. Convert only the necessary data to Python lists

         # Add new tokens
-        new_tokens = [new_tokens_host[r.py_seq_slot] for r in reqs_with_new_tokens]
+        reqs_with_new_tokens = []
+        seq_slots = []
+        seq_slots_need_log_probs = []
+        for request in reqs:
+            if sequence_lengths_host_data[request.py_seq_slot] <= request.get_num_tokens(0):
+                continue
+
+            reqs_with_new_tokens.append(request)
+            seq_slots.append(request.py_seq_slot)
+
+            if request.py_return_log_probs:
+                seq_slots_need_log_probs.append(request.py_seq_slot)
+
+        # [maxTokensPerStep, batchSize, maxBeamWidth]
+        new_tokens = state.host.new_tokens[0, seq_slots, 0].tolist()
         add_new_tokens_to_requests(reqs_with_new_tokens, new_tokens, 0)

         # Log probs
-        if log_probs_host_tensor is not None:
-            # Log probs
-            seq_slots = []
-            seq_lens = []
-            for request in reqs_with_new_tokens:
-                if request.py_return_log_probs:
-                    seq_slot = request.py_seq_slot
-                    seq_slots.append(seq_slot)
-                    seq_lens.append(sequence_lengths_host_data[seq_slot] - 1)
-
-            log_probs_host = log_probs_host_tensor[seq_slots, 0, seq_lens].tolist()
-            idx = 0
-            for request in reqs_with_new_tokens:
+        if state.host.log_probs is not None:
+            # [batchSize, maxBeamWidth]
+            seq_last_idx = state.host.sequence_lengths[seq_slots_need_log_probs, 0] - 1
+            # [batchSize, maxBeamWidth, maxSequenceLength]
+            log_probs_host = state.host.log_probs[
+                seq_slots_need_log_probs, 0, seq_last_idx
+            ].tolist()
+            # [batchSize, maxBeamWidth]
+            cum_log_probs_host = state.host.cum_log_probs[seq_slots_need_log_probs, 0].tolist()
+
+            log_probs_idx = 0
+            for request, new_token in zip(reqs_with_new_tokens, new_tokens):
                 if request.py_return_log_probs:
                     log_probs = [
                         {
-                            new_tokens_host[seq_slot]: Logprob(
-                                logprob=log_probs_host[idx],
+                            new_token: Logprob(
+                                logprob=log_probs_host[log_probs_idx],
                                 rank=1,
                             )
                         }
                     ]
-                    cum_log_probs = [
-                        cum_log_probs_host[seq_slot][0]
-                        if isinstance(cum_log_probs_host[seq_slot], list)
-                        else cum_log_probs_host[seq_slot]
-                    ]
+                    cum_log_probs = [cum_log_probs_host[log_probs_idx]]
                     request.py_result.append_log_probs([log_probs], cum_log_probs)
-                    idx += 1
+                    log_probs_idx += 1

         for request in reqs:
             request.py_decoding_iter += 1
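
The NB comment is the heart of this change: instead of reading tensor elements one request at a time, the new code builds plain-Python index lists in a single pass, gathers everything with one advanced-indexing operation, and calls `.tolist()` exactly once, so the per-request loops afterwards only touch Python scalars. Below is a minimal, self-contained sketch of that pattern; the function names and tensor shapes are illustrative only, not the TensorRT-LLM API.

```python
import torch

# Hypothetical sketch of the batched-gather pattern from the diff above.
# Shapes follow the diff's comments; all names here are illustrative.


def gather_new_tokens(new_tokens: torch.Tensor, seq_slots: list[int]) -> list[int]:
    # new_tokens: [maxTokensPerStep, batchSize, maxBeamWidth]
    # One indexing op + one .tolist() instead of a tensor access per request.
    return new_tokens[0, seq_slots, 0].tolist()


def gather_last_log_probs(
    log_probs: torch.Tensor,
    sequence_lengths: torch.Tensor,
    seq_slots: list[int],
) -> list[float]:
    # log_probs: [batchSize, maxBeamWidth, maxSequenceLength]
    # sequence_lengths: [batchSize, maxBeamWidth]
    # Paired index lists select one (slot_i, beam 0, last_position_i)
    # element per request in a single gather.
    last_idx = sequence_lengths[seq_slots, 0] - 1
    return log_probs[seq_slots, 0, last_idx].tolist()


if __name__ == "__main__":
    tokens = torch.arange(6).reshape(1, 6, 1)       # batchSize=6, beam width 1
    assert gather_new_tokens(tokens, [4, 1]) == [4, 1]

    lp = torch.zeros(6, 1, 8)
    lp[4, 0, 2] = -0.5                              # last token of slot 4
    lp[1, 0, 0] = -1.5                              # last token of slot 1
    lens = torch.tensor([[1], [1], [1], [1], [3], [1]])
    assert gather_last_log_probs(lp, lens, [4, 1]) == [-0.5, -1.5]
```

This is also why the old per-request `new_tokens_host[seq_slot]` and the `isinstance(..., list)` branch for `cum_log_probs_host` disappear: after a single batched gather, every hot loop iterates over pre-converted Python lists in the same order as `reqs_with_new_tokens`.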