Skip to content

Commit b882393

Browse files
[https://nvbugs/5720357][fix] Fix index offset overflow in custom Top-K kernel and corresponding UT case (#10027)
Signed-off-by: longcheng-nv <243710427+longcheng-nv@users.noreply.github.com> Co-authored-by: Chang Liu (Enterprise Products) <9713593+chang-l@users.noreply.github.com>
1 parent dfa11d8 commit b882393

File tree

2 files changed

+48
-15
lines changed

2 files changed

+48
-15
lines changed

cpp/tensorrt_llm/kernels/indexerTopK.cu

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -606,8 +606,8 @@ static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowPrefill(
606606
int rowEnd = rowEnds[rowIdx];
607607

608608
// Local pointers to this block
609-
outIndices += rowIdx * topK;
610-
logits += rowIdx * stride0;
609+
outIndices += static_cast<int64_t>(rowIdx) * topK;
610+
logits += static_cast<int64_t>(rowIdx) * stride0;
611611

612612
topKPerRowJob<kNumThreadsPerBlock, kNumBins, useRadixSort>(
613613
nullptr, logits, rowStart, rowEnd, outIndices, nullptr, stride1, topK);
@@ -638,23 +638,23 @@ static __global__ __launch_bounds__(kNumThreadsPerBlock) void topKPerRowDecode(f
638638
// Local pointers to this block
639639
if constexpr (!multipleBlocksPerRow && !mergeBlocks)
640640
{
641-
outIndices += rowIdx * topK;
641+
outIndices += static_cast<int64_t>(rowIdx) * topK;
642642
}
643643
else if constexpr (multipleBlocksPerRow)
644644
{
645645
auto const blockSize = rowEnd / gridDim.y; // 16384 / 2 = 8192
646646
rowStart = blockSize * blockIdx.y; // 8192 * 1 = 8192
647647
rowEnd = gridDim.y == blockIdx.y + 1 ? rowEnd : rowStart + blockSize;
648-
outIndices += rowIdx * gridDim.y * topK + blockIdx.y * topK;
649-
outLogits += rowIdx * gridDim.y * topK + blockIdx.y * topK;
648+
outIndices += static_cast<int64_t>(rowIdx) * gridDim.y * topK + blockIdx.y * topK;
649+
outLogits += static_cast<int64_t>(rowIdx) * gridDim.y * topK + blockIdx.y * topK;
650650
}
651651
else if constexpr (mergeBlocks)
652652
{
653653
rowEnd = numBlocksToMerge * topK;
654-
indices += rowIdx * numBlocksToMerge * topK;
655-
outIndices += rowIdx * topK;
654+
indices += static_cast<int64_t>(rowIdx) * numBlocksToMerge * topK;
655+
outIndices += static_cast<int64_t>(rowIdx) * topK;
656656
}
657-
logits += rowIdx * stride0;
657+
logits += static_cast<int64_t>(rowIdx) * stride0;
658658

659659
topKPerRowJob<kNumThreadsPerBlock, kNumBins, useRadixSort, multipleBlocksPerRow, mergeBlocks>(
660660
indices, logits, rowStart, rowEnd, outIndices, outLogits, stride1, topK);

tests/unittest/_torch/thop/parallel/test_indexer_topk.py

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,31 @@
11
import pytest
22
import torch
3+
from utils.util import getSMVersion, skip_pre_hopper
34

45
# Import tensorrt_llm to load custom CUDA operators (indexer_topk_decode, indexer_topk_prefill)
56
import tensorrt_llm # noqa: F401
67

8+
if not torch.cuda.is_available():
9+
pytest.skip("CUDA is required for indexer_topk tests", allow_module_level=True)
10+
11+
12+
def _prefill_param_values():
13+
"""
14+
Decide parameter coverage based on GPU architecture (SM version).
15+
16+
- pre-Hopper (SM < 90): skip via @skip_pre_hopper
17+
- Hopper (SM == 90): reduced coverage
18+
- Blackwell (SM >= 100): full coverage
19+
"""
20+
sm = getSMVersion()
21+
if sm >= 100: # Blackwell family
22+
return [1, 32], [4096, 8192, 32768]
23+
# Hopper (and other >= 90 but < 100, if any): reduced coverage
24+
return [1, 4], [4096, 8192, 32768]
25+
26+
27+
_PREFILL_BATCH_SIZES, _PREFILL_NUM_TOKENS = _prefill_param_values()
28+
729

830
def create_random_logits(
931
row_starts: torch.Tensor,
@@ -197,27 +219,38 @@ def test_indexer_topk_decode(batch_size, next_n, index_topk, num_tokens):
197219
), "CUDA top_k_per_row results don't match torch.topk"
198220

199221

200-
@pytest.mark.parametrize("batch_size", [1, 512, 2048])
222+
@skip_pre_hopper
223+
@pytest.mark.parametrize("batch_size", _PREFILL_BATCH_SIZES)
201224
@pytest.mark.parametrize("index_topk", [2048, 128])
202-
@pytest.mark.parametrize("num_tokens", [4096, 8192])
225+
@pytest.mark.parametrize("num_tokens", _PREFILL_NUM_TOKENS)
203226
def test_indexer_topk_prefill(batch_size, index_topk, num_tokens):
204227
torch.manual_seed(24)
205228
torch.cuda.manual_seed(24)
206229

207-
# Set input data
208-
row_starts = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
209-
row_ends = torch.arange(1, batch_size + 1, device="cuda", dtype=torch.int32)
230+
# gen random input for the sequence length
231+
seq_lens = generate_seq_lens(batch_size, index_topk, num_tokens)
232+
num_gen_tokens = seq_lens.sum()
233+
234+
# gen the row_starts and row_ends (from 1 to ...)
235+
row_starts = torch.zeros(num_gen_tokens, dtype=torch.int32, device="cuda")
236+
row_indices = torch.arange(1, seq_lens.max() + 1, dtype=torch.int32, device="cuda")
237+
row_ends = row_indices.expand(seq_lens.size(0), -1)[
238+
row_indices.expand(seq_lens.size(0), -1) <= seq_lens.unsqueeze(1)
239+
].contiguous()
210240

241+
# gen logits
211242
logits = create_random_logits(row_starts, row_ends, torch.float32, 42)
212243

213244
# Create output tensors
214-
indices = torch.empty((batch_size, index_topk), dtype=torch.int32, device="cuda")
245+
indices = torch.empty((num_gen_tokens, index_topk), dtype=torch.int32, device="cuda")
215246

216247
# Run CUDA implementation
217248
torch.ops.trtllm.indexer_topk_prefill(logits, row_starts, row_ends, indices, index_topk)
249+
torch.cuda.synchronize()
218250

219251
# Run reference implementation
220-
torch_indices = logits.topk(min(index_topk, max(row_ends)), dim=-1)[1]
252+
max_row_len = row_ends.max().item()
253+
torch_indices = logits.topk(min(index_topk, max_row_len), dim=-1)[1]
221254
mask_lo = torch_indices >= 0
222255
mask_hi = (torch_indices - (row_ends - row_starts)[:, None]) < 0
223256
mask = mask_lo & mask_hi

0 commit comments

Comments (0)