
Commit bfb2195

[TRTLLM-9687][fix] Enable pinned memory for tensor allocations in TorchSampler
- Updated tensor allocation in TorchSampler to use pinned memory, improving the performance of D2H copies.
- Modified test_sampled_token_always_in_logprobs to take a logprobs_mode parameter so log probabilities are exercised in both modes.

Signed-off-by: Stefan Niebler <[email protected]>
1 parent 3a28e24 commit bfb2195

File tree

2 files changed (+11, -6 lines)


tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 5 additions & 5 deletions
@@ -2596,11 +2596,11 @@ def _process_logprobs(
         sampled_rank_cuda = group_logprobs_cuda.sum(dim=-1).to(torch.int32)
 
         # Use a single D2H copy to reduce overheads
-        topk_vals = torch.empty_like(topk_vals_cuda, device="cpu", pin_memory=False)
-        topk_indices = torch.empty_like(topk_indices_cuda, device="cpu", pin_memory=False)
-        sampled_vals = torch.empty_like(sampled_vals_cuda, device="cpu", pin_memory=False)
-        sampled_indices = torch.empty_like(sampled_indices_cuda, device="cpu", pin_memory=False)
-        sampled_rank = torch.empty_like(sampled_rank_cuda, device="cpu", pin_memory=False)
+        topk_vals = torch.empty_like(topk_vals_cuda, device="cpu", pin_memory=True)
+        topk_indices = torch.empty_like(topk_indices_cuda, device="cpu", pin_memory=True)
+        sampled_vals = torch.empty_like(sampled_vals_cuda, device="cpu", pin_memory=True)
+        sampled_indices = torch.empty_like(sampled_indices_cuda, device="cpu", pin_memory=True)
+        sampled_rank = torch.empty_like(sampled_rank_cuda, device="cpu", pin_memory=True)
 
         topk_vals.copy_(topk_vals_cuda, non_blocking=True)
         topk_indices.copy_(topk_indices_cuda, non_blocking=True)
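
To see why this helps, here is a minimal standalone sketch of the pinned-buffer plus non_blocking copy pattern the updated code follows. It is not taken from the repository; the tensor name and shape are illustrative, and it assumes a CUDA device is available.

```python
import torch

# Illustrative stand-in for one of the *_cuda tensors in the diff.
vals_cuda = torch.randn(128, 8, device="cuda")

# Page-locked (pinned) host buffer with matching shape and dtype. Pinned memory
# lets the D2H copy below run asynchronously on the CUDA stream; with pageable
# host memory the transfer typically degrades to a synchronous staged copy.
vals_cpu = torch.empty_like(vals_cuda, device="cpu", pin_memory=True)

# Enqueue the device-to-host copy without blocking the Python thread.
vals_cpu.copy_(vals_cuda, non_blocking=True)

# The host must synchronize before it reads the copied values.
torch.cuda.current_stream().synchronize()
print(vals_cpu.sum())
```

In the diff, all five buffers are filled this way before a single synchronization point, which is what makes the batched non_blocking copies worthwhile.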

tests/unittest/_torch/sampler/test_logits_logprobs.py

Lines changed: 6 additions & 1 deletion
@@ -258,8 +258,10 @@ def test_generate_async_with_return_logits(
 
 @pytest.mark.parametrize("logprobs_k", [0, 1, 3],
                          ids=["top_0", "top_1", "top_3"])
+@pytest.mark.parametrize("logprobs_mode", ["raw", "processed"])
 @pytest.mark.threadleak(enabled=False)
-def test_sampled_token_always_in_logprobs(logprobs_k: int, simple_llm: LLM):
+def test_sampled_token_always_in_logprobs(logprobs_k: int, logprobs_mode: str,
+                                           simple_llm: LLM):
     """Two scenarios:
     - logprobs=0: Returns only sampled token (1 element)
     - logprobs=K (K>0): Returns top-K tokens + sampled token if not in top-K (up to K+1 elements)
@@ -270,6 +272,7 @@ def test_sampled_token_always_in_logprobs(logprobs_k: int, simple_llm: LLM):
         temperature=0.7,
         top_p=0.9,
         logprobs=logprobs_k,
+        logprobs_mode=logprobs_mode,
     )
 
     for output in simple_llm.generate(["The future of AI is"],
@@ -474,6 +477,8 @@ def test_processed_logprobs_e2e(logprobs_k: int, simple_llm: LLM):
         num_logits = len(generation_logits)
 
         for token_idx, token_logprobs_dict in enumerate(logprobs[:num_logits]):
+            assert token_ids[
+                token_idx] in token_logprobs_dict, "Sampled token not in logprobs"
 
             logits_for_token = generation_logits[token_idx:token_idx + 1]
             topk = sampling_params_list[req_idx].top_k
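
A note on the parametrization above: stacking two pytest.mark.parametrize decorators makes pytest run the cross product of the value sets, so the test now covers 3 logprobs_k values x 2 logprobs_mode values = 6 cases. The sketch below is standalone and illustrative, not the actual test, which additionally uses the simple_llm fixture and drives a real LLM.

```python
import pytest

# Stacked parametrize decorators expand into the full cross product of cases.
@pytest.mark.parametrize("logprobs_k", [0, 1, 3], ids=["top_0", "top_1", "top_3"])
@pytest.mark.parametrize("logprobs_mode", ["raw", "processed"])
def test_parametrization_matrix(logprobs_k: int, logprobs_mode: str):
    # pytest generates 6 test cases from the two parameter sets above.
    assert logprobs_k in (0, 1, 3)
    assert logprobs_mode in ("raw", "processed")
```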
