Skip to content

Commit e6187d8

Browse files
authored
[https://nvbugs/5708810][fix] Fix TRTLLMSampler (#9710)
Signed-off-by: Michal Guzek <[email protected]>
1 parent 9ba1426 commit e6187d8

File tree

3 files changed

+67
-2
lines changed

3 files changed

+67
-2
lines changed

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3135,7 +3135,11 @@ def update_requests_single_beam_single_step(self, state: SampleStateTRTLLM):
3135 3135
)
3136 3136
}
3137 3137
]
3138 -
cum_log_probs = [cum_log_probs_host[seq_slot]]
3138 +
cum_log_probs = [
3139 +
cum_log_probs_host[seq_slot][0]
3140 +
if isinstance(cum_log_probs_host[seq_slot], list)
3141 +
else cum_log_probs_host[seq_slot]
3142 +
]
3139 3143
request.py_result.append_log_probs([log_probs], cum_log_probs)
3140 3144
idx += 1
3141 3145

tensorrt_llm/executor/result.py

Lines changed: 11 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -319,7 +319,14 @@ def _handle_sequence(self,
319 319
if response_tensors.request_perf_metrics is not None:
320 320
output.request_perf_metrics = response_tensors.request_perf_metrics
321 321

322 -
if self._done:
322 +
# Check if this specific sequence is finished (not just if the entire request is done)
323 +
# This is important for best_of > n sampling where sequences finish at different times
324 +
sequence_is_finished = (finish_reasons and finish_reasons[src_idx]
325 +
!= tllm.FinishReason.NOT_FINISHED
326 +
and finish_reasons[src_idx]
327 +
!= tllm.FinishReason.CANCELLED) or self._done
328 +
329 +
if sequence_is_finished:
323 330
if finish_reasons[src_idx] == tllm.FinishReason.END_ID:
324 331
output.finish_reason = 'stop'
325 332
elif finish_reasons[src_idx] == tllm.FinishReason.STOP_WORDS:
@@ -344,6 +351,9 @@ def _handle_sequence(self,
344 351
else:
345 352
raise ValueError(
346 353
f"Unknown finish reason: {finish_reasons[src_idx]}")
354 +
355 +
# Only record stats and do tracing when the entire request is done
356 +
if self._done:
347 357
self.record_stats(output, req_perf_metrics_dict)
348 358
self.do_tracing(output, req_perf_metrics_dict)
349 359

tests/unittest/_torch/sampler/test_trtllm_sampler.py

Lines changed: 51 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -146,3 +146,54 @@ def test_torch_sampler_with_multi_token_stop_words(model_path):
146 146

147 147
assert len(text) > 0, "Should generate some text"
148 148
assert stop_string not in text, f"Stop string '{repr(stop_string)}' should not appear in the output"
149 +
150 +
151 +
@pytest.mark.high_cuda_memory
152 +
def test_trtllm_sampler_best_of_with_logprobs(model_path):
153 +
"""Test TRTLLMSampler with best_of > n and logprobs."""
154 +
155 +
llm = create_llm(model_path)
156 +
157 +
prompt = "The capital of France is"
158 +
159 +
sampling_config = SamplingParams(
160 +
max_tokens=10,
161 +
temperature=1.0,
162 +
top_k=2,
163 +
n=2, # Return 2 sequences
164 +
best_of=3, # Generate 3 candidates, pick best 2
165 +
logprobs=1 # Return log probabilities
166 +
)
167 +
168 +
outputs = llm.generate([prompt], sampling_params=sampling_config)
169 +
170 +
llm.shutdown()
171 +
172 +
assert len(outputs) == 1, "Should return one request output"
173 +
174 +
request_output = outputs[0]
175 +
completion_outputs = request_output.outputs
176 +
177 +
assert len(
178 +
completion_outputs
179 +
) == 2, f"Expected 2 outputs (n=2), got {len(completion_outputs)}"
180 +
181 +
for i, output in enumerate(completion_outputs):
182 +
assert len(output.text) > 0, f"Output {i} should have generated text"
183 +
184 +
assert output.finish_reason is not None, \
185 +
f"Output {i} must have a finish_reason"
186 +
187 +
assert output.cumulative_logprob is not None, \
188 +
f"Output {i} should have cumulative_logprob when logprobs is requested"
189 +
assert isinstance(output.cumulative_logprob, (float, int)), \
190 +
f"Output {i} cumulative_logprob should be a number, got {type(output.cumulative_logprob)}"
191 +
192 +
assert output.logprobs is not None, \
193 +
f"Output {i} should have logprobs when logprobs=1"
194 +
assert len(output.logprobs) == len(output.token_ids), \
195 +
f"Output {i} should have logprobs for each token"
196 +
197 +
198 +
if len(completion_outputs) >= 2:
199 +
assert completion_outputs[0].cumulative_logprob >= completion_outputs[1].cumulative_logprob, \
200 +
"Outputs should be sorted by cumulative log probability (best first)"

0 commit comments

Comments
 (0)