 
 @pytest.mark.parametrize('use_cuda_graph', [False, True])
 def test_kwargs(use_cuda_graph: bool, device: str):
-    N = 1024
-    src = torch.randn(N, device=device)
-    dst = torch.empty(N, device=device)
+    M, N = 1024, 16
+    src = torch.randn(M * N, device=device)
+    dst = torch.empty(M * N, device=device)
 
-    configs = [triton.Config(kwargs={'BLOCK_SIZE': 32}), triton.Config(kwargs={'BLOCK_SIZE': 128})]
+    configs = [triton.Config(kwargs={'BLOCK_SIZE_M': 32}), triton.Config(kwargs={'BLOCK_SIZE_M': 128})]
 
-    @triton.autotune(configs=configs, key=['N'], warmup=1, rep=1, use_cuda_graph=use_cuda_graph)
+    @triton.autotune(configs=configs, key=['M'], warmup=1, rep=1, use_cuda_graph=use_cuda_graph)
     @triton.jit
-    def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr):
-        offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
-        x = tl.load(src + offsets, mask=offsets < N)
-        tl.store(dst + offsets, x, mask=offsets < N)
-
-    grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']), )
-    _kernel[grid](dst, src, N)
-    _kernel[grid](dst=dst, src=src, N=N)
+    def _kernel(dst, src, stride_m: tl.constexpr, M, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_M: tl.constexpr):
+        offsets_m = tl.program_id(0) * stride_m + tl.arange(0, BLOCK_SIZE_M)
+        offsets_n = tl.arange(0, BLOCK_SIZE_N)
+        x = tl.load(src + offsets_m[:, None] * BLOCK_SIZE_N + offsets_n[None, :])
+        tl.store(dst + offsets_m[:, None] * BLOCK_SIZE_N + offsets_n[None, :], x)
+
+    grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE_M']), )
+    _kernel[grid](dst, src, N, M, N)
+    # the keyword args can be passed in arbitrary order
+    _kernel[grid](dst=dst, src=src, M=M // 2, stride_m=N, BLOCK_SIZE_N=N)
+    assert len(_kernel.cache) == 2
 
 
 def test_restore(device):
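
For context (not part of the change itself), a short annotated sketch of what the two launches exercise, assuming standard triton.autotune semantics; all names come from the diff above:

    # Positional launch: arguments bind in signature order (dst, src, stride_m, M,
    # BLOCK_SIZE_N); BLOCK_SIZE_M is not passed here -- the autotuner supplies it
    # from the winning triton.Config.
    _kernel[grid](dst, src, N, M, N)

    # Keyword launch: the same arguments can be given by name in any order.
    _kernel[grid](dst=dst, src=src, M=M // 2, stride_m=N, BLOCK_SIZE_N=N)

    # key=['M'] makes the autotuner cache one tuned config per distinct value of M;
    # the launches use M and M // 2, hence exactly two cache entries.
    assert len(_kernel.cache) == 2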