@@ -3093,13 +3093,8 @@ def update_requests(
30933093 @nvtx_range ("update_requests_single_beam_single_step" )
30943094 def update_requests_single_beam_single_step (self , state : SampleStateTRTLLM ):
30953095 """Specialization of update_requests for single beam and single step"""
3096- new_tokens_host = state .host .new_tokens .flatten ().tolist ()
30973096 sequence_lengths_host_data = state .host .sequence_lengths .flatten ().tolist ()
30983097 finish_reasons = state .host .finish_reasons .flatten ().tolist ()
3099- log_probs_host_tensor = state .host .log_probs
3100- cum_log_probs_host = (
3101- state .host .cum_log_probs .tolist () if state .host .cum_log_probs is not None else None
3102- )
31033098
31043099 reqs = [
31053100 r for r in state .scheduled_requests .context_requests if not r .is_context_init_state
@@ -3109,44 +3104,53 @@ def update_requests_single_beam_single_step(self, state: SampleStateTRTLLM):
31093104 if not r .is_generation_complete_state
31103105 ]
31113106
3112- reqs_with_new_tokens = [
3113- r for r in reqs if ( sequence_lengths_host_data [ r . py_seq_slot ] > r . get_num_tokens ( 0 ))
3114- ]
3107+ # NB: To ensure good performance, we must
3108+ # 1. Avoid accessing torch.Tensor object inside the for-each-request loops
3109+ # 2. Convert only necessary data to Python list
31153110
31163111 # Add new tokens
3117- new_tokens = [new_tokens_host [r .py_seq_slot ] for r in reqs_with_new_tokens ]
3112+ reqs_with_new_tokens = []
3113+ seq_slots = []
3114+ seq_slots_need_log_probs = []
3115+ for request in reqs :
3116+ if sequence_lengths_host_data [request .py_seq_slot ] <= request .get_num_tokens (0 ):
3117+ continue
3118+
3119+ reqs_with_new_tokens .append (request )
3120+ seq_slots .append (request .py_seq_slot )
3121+
3122+ if request .py_return_log_probs :
3123+ seq_slots_need_log_probs .append (request .py_seq_slot )
3124+
3125+ # [maxTokensPerStep, batchSize, maxBeamWidth]
3126+ new_tokens = state .host .new_tokens [0 , seq_slots , 0 ].tolist ()
31183127 add_new_tokens_to_requests (reqs_with_new_tokens , new_tokens , 0 )
31193128
31203129 # Log probs
3121- if log_probs_host_tensor is not None :
3122- # Log probs
3123- seq_slots = []
3124- seq_lens = []
3125- for request in reqs_with_new_tokens :
3126- if request .py_return_log_probs :
3127- seq_slot = request .py_seq_slot
3128- seq_slots .append (seq_slot )
3129- seq_lens .append (sequence_lengths_host_data [seq_slot ] - 1 )
3130-
3131- log_probs_host = log_probs_host_tensor [seq_slots , 0 , seq_lens ].tolist ()
3132- idx = 0
3133- for request in reqs_with_new_tokens :
3130+ if state .host .log_probs is not None :
3131+ # [batchSize, maxBeamWidth]
3132+ seq_last_idx = state .host .sequence_lengths [seq_slots_need_log_probs , 0 ] - 1
3133+ # [batchSize, maxBeamWidth, maxSequenceLength]
3134+ log_probs_host = state .host .log_probs [
3135+ seq_slots_need_log_probs , 0 , seq_last_idx
3136+ ].tolist ()
3137+ # [batchSize, maxBeamWidth]
3138+ cum_log_probs_host = state .host .cum_log_probs [seq_slots_need_log_probs , 0 ].tolist ()
3139+
3140+ log_probs_idx = 0
3141+ for request , new_token in zip (reqs_with_new_tokens , new_tokens ):
31343142 if request .py_return_log_probs :
31353143 log_probs = [
31363144 {
3137- new_tokens_host [ seq_slot ] : Logprob (
3138- logprob = log_probs_host [idx ],
3145+ new_token : Logprob (
3146+ logprob = log_probs_host [log_probs_idx ],
31393147 rank = 1 ,
31403148 )
31413149 }
31423150 ]
3143- cum_log_probs = [
3144- cum_log_probs_host [seq_slot ][0 ]
3145- if isinstance (cum_log_probs_host [seq_slot ], list )
3146- else cum_log_probs_host [seq_slot ]
3147- ]
3151+ cum_log_probs = [cum_log_probs_host [log_probs_idx ]]
31483152 request .py_result .append_log_probs ([log_probs ], cum_log_probs )
3149- idx += 1
3153+ log_probs_idx += 1
31503154
31513155 for request in reqs :
31523156 request .py_decoding_iter += 1
0 commit comments