Commit f7e2456

[TRTLLM-9680][perf] Optimize TRTLLMSampler log_probs performance (Core fix has been merged via #9353) (#9655)
Signed-off-by: Yuan Tong <[email protected]>
1 parent 00c0564 commit f7e2456

File tree

1 file changed (+34, -30 lines)

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 34 additions & 30 deletions
@@ -3088,13 +3088,8 @@ def update_requests(
     @nvtx_range("update_requests_single_beam_single_step")
     def update_requests_single_beam_single_step(self, state: SampleStateTRTLLM):
         """Specialization of update_requests for single beam and single step"""
-        new_tokens_host = state.host.new_tokens.flatten().tolist()
         sequence_lengths_host_data = state.host.sequence_lengths.flatten().tolist()
         finish_reasons = state.host.finish_reasons.flatten().tolist()
-        log_probs_host_tensor = state.host.log_probs
-        cum_log_probs_host = (
-            state.host.cum_log_probs.tolist() if state.host.cum_log_probs is not None else None
-        )

         reqs = [
             r for r in state.scheduled_requests.context_requests if not r.is_context_init_state
@@ -3104,44 +3099,53 @@ def update_requests_single_beam_single_step(self, state: SampleStateTRTLLM):
             if not r.is_generation_complete_state
         ]

-        reqs_with_new_tokens = [
-            r for r in reqs if (sequence_lengths_host_data[r.py_seq_slot] > r.get_num_tokens(0))
-        ]
+        # NB: To ensure good performance, we must
+        # 1. Avoid accessing torch.Tensor object inside the for-each-request loops
+        # 2. Convert only necessary data to Python list

         # Add new tokens
-        new_tokens = [new_tokens_host[r.py_seq_slot] for r in reqs_with_new_tokens]
+        reqs_with_new_tokens = []
+        seq_slots = []
+        seq_slots_need_log_probs = []
+        for request in reqs:
+            if sequence_lengths_host_data[request.py_seq_slot] <= request.get_num_tokens(0):
+                continue
+
+            reqs_with_new_tokens.append(request)
+            seq_slots.append(request.py_seq_slot)
+
+            if request.py_return_log_probs:
+                seq_slots_need_log_probs.append(request.py_seq_slot)
+
+        # [maxTokensPerStep, batchSize, maxBeamWidth]
+        new_tokens = state.host.new_tokens[0, seq_slots, 0].tolist()
         add_new_tokens_to_requests(reqs_with_new_tokens, new_tokens, 0)

         # Log probs
-        if log_probs_host_tensor is not None:
-            # Log probs
-            seq_slots = []
-            seq_lens = []
-            for request in reqs_with_new_tokens:
-                if request.py_return_log_probs:
-                    seq_slot = request.py_seq_slot
-                    seq_slots.append(seq_slot)
-                    seq_lens.append(sequence_lengths_host_data[seq_slot] - 1)
-
-            log_probs_host = log_probs_host_tensor[seq_slots, 0, seq_lens].tolist()
-            idx = 0
-            for request in reqs_with_new_tokens:
+        if state.host.log_probs is not None:
+            # [batchSize, maxBeamWidth]
+            seq_last_idx = state.host.sequence_lengths[seq_slots_need_log_probs, 0] - 1
+            # [batchSize, maxBeamWidth, maxSequenceLength]
+            log_probs_host = state.host.log_probs[
+                seq_slots_need_log_probs, 0, seq_last_idx
+            ].tolist()
+            # [batchSize, maxBeamWidth]
+            cum_log_probs_host = state.host.cum_log_probs[seq_slots_need_log_probs, 0].tolist()
+
+            log_probs_idx = 0
+            for request, new_token in zip(reqs_with_new_tokens, new_tokens):
                 if request.py_return_log_probs:
                     log_probs = [
                         {
-                            new_tokens_host[seq_slot]: Logprob(
-                                logprob=log_probs_host[idx],
+                            new_token: Logprob(
+                                logprob=log_probs_host[log_probs_idx],
                                 rank=1,
                             )
                         }
                     ]
-                    cum_log_probs = [
-                        cum_log_probs_host[seq_slot][0]
-                        if isinstance(cum_log_probs_host[seq_slot], list)
-                        else cum_log_probs_host[seq_slot]
-                    ]
+                    cum_log_probs = [cum_log_probs_host[log_probs_idx]]
                     request.py_result.append_log_probs([log_probs], cum_log_probs)
-                    idx += 1
+                    log_probs_idx += 1

         for request in reqs:
             request.py_decoding_iter += 1
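
The core pattern behind this change, as a minimal standalone sketch: instead of indexing a tensor once per request inside the Python loop, collect the slot indices first, then read everything with one batched fancy-indexing gather and a single .tolist() conversion. The names and shapes below are illustrative assumptions mirroring the diff's shape comments, not the repository's actual SampleStateTRTLLM layout.

import torch

# Assumed shapes, mirroring the diff's comments (illustrative only):
#   log_probs:        [batch_size, max_beam_width, max_sequence_length]
#   sequence_lengths: [batch_size, max_beam_width]
batch_size, max_seq_len = 64, 256
log_probs = torch.randn(batch_size, 1, max_seq_len)
sequence_lengths = torch.randint(1, max_seq_len, (batch_size, 1))

# Slots whose requests produced a new token this step (here: all of them).
seq_slots = list(range(batch_size))

# Slow pattern: one tensor access and one Python-object conversion per
# request inside the loop.
slow = [
    log_probs[slot, 0, sequence_lengths[slot, 0] - 1].item() for slot in seq_slots
]

# Fast pattern: one batched gather, then one .tolist() call.
seq_last_idx = sequence_lengths[seq_slots, 0] - 1  # shape [len(seq_slots)]
fast = log_probs[seq_slots, 0, seq_last_idx].tolist()

assert slow == fast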

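A rough way to measure the effect at serving-step batch sizes is a timeit comparison of the two patterns; this is a hypothetical harness, not a benchmark from the commit, and numbers vary by machine.

import timeit

import torch

batch_size, max_seq_len = 256, 4096
log_probs = torch.randn(batch_size, 1, max_seq_len)
seq_lens = torch.randint(1, max_seq_len, (batch_size, 1))
slots = list(range(batch_size))

# One tensor access per request.
per_request = timeit.timeit(
    lambda: [log_probs[s, 0, seq_lens[s, 0] - 1].item() for s in slots],
    number=100,
)
# One batched gather plus a single .tolist().
batched = timeit.timeit(
    lambda: log_probs[slots, 0, seq_lens[slots, 0] - 1].tolist(),
    number=100,
)
print(f"per-request: {per_request:.4f}s  batched: {batched:.4f}s")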