@@ -1,10 +1,13 @@
 import expecttest
+import torch
+import pytest
 
 from triton import knobs
 from triton.experimental import gluon
 from triton.experimental.gluon import language as ttgl
 from triton._filecheck import filecheck_test
 import triton.language as tl
+from triton._internal_testing import is_cuda
 
 
 @gluon.jit
@@ -39,7 +42,7 @@ def test_convert_layout(fresh_knobs):
 def shared_memory_kernel(XBLOCK: ttgl.constexpr, YBLOCK: ttgl.constexpr, layout_a: ttgl.constexpr,
                          layout_b: ttgl.constexpr, smem_layout: ttgl.constexpr):
     a = ttgl.full([XBLOCK, YBLOCK], 0, ttgl.int32, layout_a)
-    mem = ttgl.allocate_shared(ttgl.int32, a.shape, smem_layout, a)
+    mem = ttgl.allocate_shared_memory(ttgl.int32, a.shape, smem_layout, a)
     b = mem.load(layout_b)  # noqa: F841
     mem.store(a)
 
@@ -72,6 +75,47 @@ def test_shared_memory(fresh_knobs):
 """)
 
 
+@gluon.jit
+def tensor_memory_kernel(layout: ttgl.constexpr, tmem_layout: ttgl.constexpr):
+    XBLOCK: ttgl.constexpr = tmem_layout.block[0]
+    YBLOCK: ttgl.constexpr = tmem_layout.block[1]
+    a = ttgl.full([XBLOCK, YBLOCK], 0, ttgl.int32, layout)
+    mem = ttgl.nvidia.blackwell.allocate_tensor_memory(ttgl.int32, a.shape, tmem_layout, a)
+    b = mem.load(layout)  # noqa: F841
+    mem.store(a)
+    slice1 = mem.subslice(0, YBLOCK // 2)  # noqa: F841
+    slice2 = mem.subslice(YBLOCK // 2, YBLOCK // 2)  # noqa: F841
+
+
+@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] != 10,
+                    reason="Requires blackwell tensor cores")
+def test_tensor_memory(fresh_knobs):
+    knobs.compilation.disable_line_info = True
+
+    layout = ttgl.BlockedLayout(size_per_thread=[1, 64], threads_per_warp=[32, 1], warps_per_cta=[4, 1], order=[0, 1])
+    tmem_layout = ttgl.nvidia.blackwell.TensorMemoryLayout(block=[128, 128], unpacked=True)
+    h = tensor_memory_kernel.warmup(layout, tmem_layout, num_warps=4, grid=(1, ))
+    expecttest.assert_expected_inline(
+        h.asm["ttgir"], """\
+#blocked = #ttg.blocked<{sizePerThread = [1, 64], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
+#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>
+module attributes {"ttg.num-warps" = 4 : i32} {
+  tt.func public @tensor_memory_kernel() attributes {noinline = false} {
+    %c0_i32 = arith.constant 0 : i32 loc(#loc)
+    %cst = arith.constant dense<0> : tensor<128x128xi32, #blocked> loc(#loc)
+    %result = ttng.tmem_alloc %cst : (tensor<128x128xi32, #blocked>) -> !ttg.memdesc<128x128xi32, #tmem, #ttng.tensor_memory, mutable> loc(#loc)
+    %result_0 = ttng.tmem_load %result : !ttg.memdesc<128x128xi32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xi32, #blocked> loc(#loc)
+    %true = arith.constant true loc(#loc)
+    ttng.tmem_store %cst, %result, %true : tensor<128x128xi32, #blocked> -> !ttg.memdesc<128x128xi32, #tmem, #ttng.tensor_memory, mutable> loc(#loc)
+    %0 = ttng.tmem_subslice %result {N = 0 : i32} : !ttg.memdesc<128x128xi32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<128x64xi32, #tmem, #ttng.tensor_memory, mutable, 128x128> loc(#loc)
+    %1 = ttng.tmem_subslice %result {N = 64 : i32} : !ttg.memdesc<128x128xi32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<128x64xi32, #tmem, #ttng.tensor_memory, mutable, 128x128> loc(#loc)
+    tt.return loc(#loc)
+  } loc(#loc)
+} loc(#loc)
+#loc = loc(unknown)
+""")
+
+
 @gluon.jit
 def warp_specialize_default(a, b):
     return b, a