|
1 | 1 | import pytest |
2 | 2 | import torch |
| 3 | +from utils.util import getSMVersion, skip_pre_hopper |
3 | 4 |
|
4 | 5 | # Import tensorrt_llm to load custom CUDA operators (indexer_topk_decode, indexer_topk_prefill) |
5 | 6 | import tensorrt_llm # noqa: F401 |
6 | 7 |
|
| 8 | +if not torch.cuda.is_available(): |
| 9 | + pytest.skip("CUDA is required for indexer_topk tests", allow_module_level=True) |
| 10 | + |
| 11 | + |
| 12 | +def _prefill_param_values(): |
| 13 | + """ |
| 14 | + Decide parameter coverage based on GPU architecture (SM version). |
| 15 | +
|
| 16 | + - pre-Hopper (SM < 90): skip via @skip_pre_hopper |
| 17 | + - Hopper (SM == 90): reduced coverage |
| 18 | + - Blackwell (SM >= 100): full coverage |
| 19 | + """ |
| 20 | + sm = getSMVersion() |
| 21 | + if sm >= 100: # Blackwell family |
| 22 | + return [1, 32], [4096, 8192, 32768] |
| 23 | +    # Hopper (SM in [90, 100)): reduced batch-size coverage; num_tokens is unchanged |
| 24 | + return [1, 4], [4096, 8192, 32768] |
| 25 | + |
| 26 | + |
| 27 | +_PREFILL_BATCH_SIZES, _PREFILL_NUM_TOKENS = _prefill_param_values() |
| 28 | + |
7 | 29 |
|
8 | 30 | def create_random_logits( |
9 | 31 | row_starts: torch.Tensor, |
@@ -197,27 +219,38 @@ def test_indexer_topk_decode(batch_size, next_n, index_topk, num_tokens): |
197 | 219 | ), "CUDA top_k_per_row results don't match torch.topk" |
198 | 220 |
|
199 | 221 |
|
200 | | -@pytest.mark.parametrize("batch_size", [1, 512, 2048]) |
| 222 | +@skip_pre_hopper |
| 223 | +@pytest.mark.parametrize("batch_size", _PREFILL_BATCH_SIZES) |
201 | 224 | @pytest.mark.parametrize("index_topk", [2048, 128]) |
202 | | -@pytest.mark.parametrize("num_tokens", [4096, 8192]) |
| 225 | +@pytest.mark.parametrize("num_tokens", _PREFILL_NUM_TOKENS) |
203 | 226 | def test_indexer_topk_prefill(batch_size, index_topk, num_tokens): |
204 | 227 | torch.manual_seed(24) |
205 | 228 | torch.cuda.manual_seed(24) |
206 | 229 |
|
207 | | - # Set input data |
208 | | - row_starts = torch.zeros(batch_size, dtype=torch.int32, device="cuda") |
209 | | - row_ends = torch.arange(1, batch_size + 1, device="cuda", dtype=torch.int32) |
| 230 | +    # Generate random per-request sequence lengths for this batch |
| 231 | + seq_lens = generate_seq_lens(batch_size, index_topk, num_tokens) |
| 232 | +    num_gen_tokens = seq_lens.sum()  # NOTE(review): 0-dim tensor; accepted as a size via __index__, but .item() would be clearer |
| 233 | + |
| 234 | +    # Build row_starts (all zeros) and row_ends (1..seq_len for each request) |
| 235 | + row_starts = torch.zeros(num_gen_tokens, dtype=torch.int32, device="cuda") |
| 236 | + row_indices = torch.arange(1, seq_lens.max() + 1, dtype=torch.int32, device="cuda") |
| 237 | + row_ends = row_indices.expand(seq_lens.size(0), -1)[ |
| 238 | + row_indices.expand(seq_lens.size(0), -1) <= seq_lens.unsqueeze(1) |
| 239 | + ].contiguous() |
210 | 240 |
|
| 241 | +    # Generate random logits for every row |
211 | 242 | logits = create_random_logits(row_starts, row_ends, torch.float32, 42) |
212 | 243 |
|
213 | 244 | # Create output tensors |
214 | | - indices = torch.empty((batch_size, index_topk), dtype=torch.int32, device="cuda") |
| 245 | + indices = torch.empty((num_gen_tokens, index_topk), dtype=torch.int32, device="cuda") |
215 | 246 |
|
216 | 247 | # Run CUDA implementation |
217 | 248 | torch.ops.trtllm.indexer_topk_prefill(logits, row_starts, row_ends, indices, index_topk) |
| 249 | + torch.cuda.synchronize() |
218 | 250 |
|
219 | 251 | # Run reference implementation |
220 | | - torch_indices = logits.topk(min(index_topk, max(row_ends)), dim=-1)[1] |
| 252 | + max_row_len = row_ends.max().item() |
| 253 | + torch_indices = logits.topk(min(index_topk, max_row_len), dim=-1)[1] |
221 | 254 | mask_lo = torch_indices >= 0 |
222 | 255 | mask_hi = (torch_indices - (row_ends - row_starts)[:, None]) < 0 |
223 | 256 | mask = mask_lo & mask_hi |
|
0 commit comments