Commit e84214f

tongyuantongyu authored and codego7250 committed
[TRTLLM-9680][perf] Optimize TRTLLMSampler log_probs performance (Core fix has been merged via NVIDIA#9353) (NVIDIA#9655)
Signed-off-by: Yuan Tong <[email protected]>
1 parent f67076b · commit e84214f

File tree

1 file changed: +34 -30 lines changed


tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 34 additions & 30 deletions
@@ -3093,13 +3093,8 @@ def update_requests(
     @nvtx_range("update_requests_single_beam_single_step")
     def update_requests_single_beam_single_step(self, state: SampleStateTRTLLM):
         """Specialization of update_requests for single beam and single step"""
-        new_tokens_host = state.host.new_tokens.flatten().tolist()
         sequence_lengths_host_data = state.host.sequence_lengths.flatten().tolist()
         finish_reasons = state.host.finish_reasons.flatten().tolist()
-        log_probs_host_tensor = state.host.log_probs
-        cum_log_probs_host = (
-            state.host.cum_log_probs.tolist() if state.host.cum_log_probs is not None else None
-        )
 
         reqs = [
             r for r in state.scheduled_requests.context_requests if not r.is_context_init_state
@@ -3109,44 +3104,53 @@ def update_requests_single_beam_single_step(self, state: SampleStateTRTLLM):
             if not r.is_generation_complete_state
         ]
 
-        reqs_with_new_tokens = [
-            r for r in reqs if (sequence_lengths_host_data[r.py_seq_slot] > r.get_num_tokens(0))
-        ]
+        # NB: To ensure good performance, we must
+        # 1. Avoid accessing torch.Tensor object inside the for-each-request loops
+        # 2. Convert only necessary data to Python list
 
         # Add new tokens
-        new_tokens = [new_tokens_host[r.py_seq_slot] for r in reqs_with_new_tokens]
+        reqs_with_new_tokens = []
+        seq_slots = []
+        seq_slots_need_log_probs = []
+        for request in reqs:
+            if sequence_lengths_host_data[request.py_seq_slot] <= request.get_num_tokens(0):
+                continue
+
+            reqs_with_new_tokens.append(request)
+            seq_slots.append(request.py_seq_slot)
+
+            if request.py_return_log_probs:
+                seq_slots_need_log_probs.append(request.py_seq_slot)
+
+        # [maxTokensPerStep, batchSize, maxBeamWidth]
+        new_tokens = state.host.new_tokens[0, seq_slots, 0].tolist()
         add_new_tokens_to_requests(reqs_with_new_tokens, new_tokens, 0)
 
         # Log probs
-        if log_probs_host_tensor is not None:
-            # Log probs
-            seq_slots = []
-            seq_lens = []
-            for request in reqs_with_new_tokens:
-                if request.py_return_log_probs:
-                    seq_slot = request.py_seq_slot
-                    seq_slots.append(seq_slot)
-                    seq_lens.append(sequence_lengths_host_data[seq_slot] - 1)
-
-            log_probs_host = log_probs_host_tensor[seq_slots, 0, seq_lens].tolist()
-            idx = 0
-            for request in reqs_with_new_tokens:
+        if state.host.log_probs is not None:
+            # [batchSize, maxBeamWidth]
+            seq_last_idx = state.host.sequence_lengths[seq_slots_need_log_probs, 0] - 1
+            # [batchSize, maxBeamWidth, maxSequenceLength]
+            log_probs_host = state.host.log_probs[
+                seq_slots_need_log_probs, 0, seq_last_idx
+            ].tolist()
+            # [batchSize, maxBeamWidth]
+            cum_log_probs_host = state.host.cum_log_probs[seq_slots_need_log_probs, 0].tolist()
+
+            log_probs_idx = 0
+            for request, new_token in zip(reqs_with_new_tokens, new_tokens):
                 if request.py_return_log_probs:
                     log_probs = [
                         {
-                            new_tokens_host[seq_slot]: Logprob(
-                                logprob=log_probs_host[idx],
+                            new_token: Logprob(
+                                logprob=log_probs_host[log_probs_idx],
                                 rank=1,
                             )
                         }
                     ]
-                    cum_log_probs = [
-                        cum_log_probs_host[seq_slot][0]
-                        if isinstance(cum_log_probs_host[seq_slot], list)
-                        else cum_log_probs_host[seq_slot]
-                    ]
+                    cum_log_probs = [cum_log_probs_host[log_probs_idx]]
                     request.py_result.append_log_probs([log_probs], cum_log_probs)
-                    idx += 1
+                    log_probs_idx += 1
 
         for request in reqs:
             request.py_decoding_iter += 1
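The core of the optimization is stated in the NB comment added above: collect the needed sequence-slot indices in plain Python first, then pull data off the host tensors with a single advanced-indexing call followed by one .tolist(), rather than indexing a torch.Tensor once per request inside the loop. A minimal sketch of that pattern follows; all tensor names and shapes here are illustrative assumptions, not the actual TRT-LLM objects:

import torch

# Mirrors the layout of new_tokens in the diff:
# [maxTokensPerStep, batchSize, maxBeamWidth] (shape assumed for illustration)
new_tokens = torch.randint(0, 32000, (1, 4096, 1))
seq_slots = list(range(0, 4096, 2))  # slots that produced a token this step

# Slow: one tensor indexing op plus one scalar conversion per request;
# every iteration pays Python-to-C++ dispatch overhead inside the loop.
tokens_slow = [int(new_tokens[0, slot, 0]) for slot in seq_slots]

# Fast: a single advanced-indexing call gathers only the needed elements,
# and one .tolist() converts them all to Python ints at once.
tokens_fast = new_tokens[0, seq_slots, 0].tolist()
assert tokens_slow == tokens_fast

# The same trick extends to gathering one element per row at a varying
# position, as the log-probs path does: pass two parallel index sequences.
log_probs = torch.randn(4096, 1, 128)  # [batchSize, maxBeamWidth, maxSeqLen]
last_idx = torch.randint(0, 128, (len(seq_slots),))  # last position per request
per_request_log_probs = log_probs[seq_slots, 0, last_idx].tolist()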
