Skip to content

Commit a0d489a

Browse files
authored
[TRTLLM-7728][perf] improve batched sampling perf for contiguous batches (#7908)
Signed-off-by: ixlmar <[email protected]>
1 parent 560ded5 commit a0d489a

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,7 @@ def _group_requests_by_sampling_strategy(
396396
requests: Iterable[LlmRequest],
397397
*,
398398
pin_memory: bool = False) -> dict[Strategy, torch.Tensor]:
399+
# NB: Client code relies on request indices in returned torch.Tensor being sorted.
399400
strategy_dict: dict[Strategy, list[int]] = defaultdict(list)
400401
for req_index, req in enumerate(requests):
401402
strategy_dict[_request_strategy(req)].append(req_index)
@@ -1372,12 +1373,20 @@ def _sample_batched_by_strategy(
13721373
len(speculation_group_indices), dtype=torch.int32)
13731374

13741375
group_logits_cuda_indices = logits_cuda_indexer[group_req_indices]
1375-
if group_logits_cuda_indices.numel() != logits_cuda.size(0):
1376+
# NB: Assuming that group_req_indices are sorted
1377+
group_req_1st_index, group_req_last_index = group_req_indices[
1378+
0], group_req_indices[-1]
1379+
if group_req_last_index - group_req_1st_index + 1 == len(
1380+
group_req_indices):
1381+
# Avoid data movement if indices are contiguous
1382+
group_logits_cuda = logits_cuda[
1383+
req_offsets[group_req_1st_index]:(
1384+
req_offsets[group_req_last_index] +
1385+
req_num_steps[group_req_last_index])]
1386+
else:
13761387
group_logits_cuda_indices_cuda = group_logits_cuda_indices.to(
13771388
device=logits_cuda.device, non_blocking=True)
13781389
group_logits_cuda = logits_cuda[group_logits_cuda_indices_cuda]
1379-
else:
1380-
group_logits_cuda = logits_cuda
13811390

13821391
# Indexer for accessing tokens in 'group_logits_cuda' (and 'group_next_tokens_cuda')
13831392
# corresponding to the requests in 'group_req_indices'.

0 commit comments

Comments
 (0)