5
5
from triton import knobs
6
6
from triton .experimental import gluon
7
7
from triton .experimental .gluon import language as ttgl
8
+ from triton .experimental .gluon .language .nvidia import blackwell
8
9
from triton .experimental .gluon .language .nvidia .blackwell import mbarrier
9
10
from triton ._filecheck import filecheck_test
10
11
import triton .language as tl
@@ -85,6 +86,7 @@ def tensor_memory_kernel(layout: ttgl.constexpr, tmem_layout: ttgl.constexpr):
85
86
XBLOCK : ttgl .constexpr = tmem_layout .block [0 ]
86
87
YBLOCK : ttgl .constexpr = tmem_layout .block [1 ]
87
88
a = ttgl .full ([XBLOCK , YBLOCK ], 0 , ttgl .int32 , layout )
89
+ _ = ttgl .nvidia .blackwell .allocate_tensor_memory (ttgl .int32 , a .shape , tmem_layout )
88
90
mem = ttgl .nvidia .blackwell .allocate_tensor_memory (ttgl .int32 , a .shape , tmem_layout , a )
89
91
b = mem .load (layout ) # noqa: F841
90
92
mem .store (a )
@@ -108,12 +110,13 @@ def test_tensor_memory(fresh_knobs):
108
110
tt.func public @tensor_memory_kernel() attributes {noinline = false} {
109
111
%c0_i32 = arith.constant 0 : i32 loc(#loc)
110
112
%cst = arith.constant dense<0> : tensor<128x128xi32, #blocked> loc(#loc)
111
- %result = ttng.tmem_alloc %cst : (tensor<128x128xi32, #blocked>) -> !ttg.memdesc<128x128xi32, #tmem, #ttng.tensor_memory, mutable> loc(#loc)
112
- %result_0 = ttng.tmem_load %result : !ttg.memdesc<128x128xi32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xi32, #blocked> loc(#loc)
113
+ %result = ttng.tmem_alloc : () -> !ttg.memdesc<128x128xi32, #tmem, #ttng.tensor_memory, mutable> loc(#loc)
114
+ %result_0 = ttng.tmem_alloc %cst : (tensor<128x128xi32, #blocked>) -> !ttg.memdesc<128x128xi32, #tmem, #ttng.tensor_memory, mutable> loc(#loc)
115
+ %result_1 = ttng.tmem_load %result_0 : !ttg.memdesc<128x128xi32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xi32, #blocked> loc(#loc)
113
116
%true = arith.constant true loc(#loc)
114
- ttng.tmem_store %cst, %result , %true : tensor<128x128xi32, #blocked> -> !ttg.memdesc<128x128xi32, #tmem, #ttng.tensor_memory, mutable> loc(#loc)
115
- %0 = ttng.tmem_subslice %result {N = 0 : i32} : !ttg.memdesc<128x128xi32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<128x64xi32, #tmem, #ttng.tensor_memory, mutable, 128x128> loc(#loc)
116
- %1 = ttng.tmem_subslice %result {N = 64 : i32} : !ttg.memdesc<128x128xi32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<128x64xi32, #tmem, #ttng.tensor_memory, mutable, 128x128> loc(#loc)
117
+ ttng.tmem_store %cst, %result_0 , %true : tensor<128x128xi32, #blocked> -> !ttg.memdesc<128x128xi32, #tmem, #ttng.tensor_memory, mutable> loc(#loc)
118
+ %0 = ttng.tmem_subslice %result_0 {N = 0 : i32} : !ttg.memdesc<128x128xi32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<128x64xi32, #tmem, #ttng.tensor_memory, mutable, 128x128> loc(#loc)
119
+ %1 = ttng.tmem_subslice %result_0 {N = 64 : i32} : !ttg.memdesc<128x128xi32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<128x64xi32, #tmem, #ttng.tensor_memory, mutable, 128x128> loc(#loc)
117
120
tt.return loc(#loc)
118
121
} loc(#loc)
119
122
} loc(#loc)
@@ -217,3 +220,40 @@ def test_mbarrier(fresh_knobs):
217
220
} loc(#loc)
218
221
#loc = loc(unknown)
219
222
""" )
223
+
224
+
225
@gluon.jit
def tcgen05_mma_kernel(nvmma_layout: ttgl.constexpr, acc_layout: ttgl.constexpr):
    """Allocate two shared-memory operands and a tensor-memory accumulator,
    then issue a single Blackwell tcgen05 MMA over them.

    Args:
        nvmma_layout: constexpr shared-memory (NVMMA) layout used for both
            MMA operands.
        acc_layout: constexpr tensor-memory layout for the accumulator.
    """
    # Both operands are 128x128 fp16 tiles in shared memory
    # (lowered to ttg.local_alloc in TTGIR).
    a = ttgl.allocate_shared_memory(ttgl.float16, [128, 128], nvmma_layout)
    b = ttgl.allocate_shared_memory(ttgl.float16, [128, 128], nvmma_layout)
    # Accumulator lives in Blackwell tensor memory (ttng.tmem_alloc in TTGIR).
    acc = blackwell.allocate_tensor_memory(ttgl.float16, [128, 128], acc_layout)
    # NOTE(review): statement order is load-bearing — test_tcgen05_mma pins the
    # exact TTGIR emitted for this kernel, so do not reorder these allocations.
    blackwell.tcgen05_mma(a, b, acc)
231
+
232
+
233
@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] != 10,
                    reason="Requires blackwell tensor core")
def test_tcgen05_mma(fresh_knobs):
    """Golden-IR test: compile tcgen05_mma_kernel and compare its TTGIR
    against the expected inline text below.

    Requires an NVIDIA Blackwell GPU (compute capability 10.x) per the
    skipif guard; `fresh_knobs` is a fixture that isolates knob changes.
    """
    # Strip loc() line info so the expected IR text stays stable across edits.
    knobs.compilation.disable_line_info = True

    # 128-byte swizzled NVMMA shared layout for fp16 (16-bit) 2-D operands.
    nvmma_layout = ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=16, rank=2)
    acc_layout = blackwell.TensorMemoryLayout([128, 128], unpacked=True)

    # warmup compiles without launching; the handle's "ttgir" assembly is
    # compared byte-for-byte against the golden text.
    h = tcgen05_mma_kernel.warmup(nvmma_layout, acc_layout, grid=(1, ))
    expecttest.assert_expected_inline(
        h.asm["ttgir"], """\
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
#smem = #ttg.shared_memory
#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>
module attributes {"ttg.num-warps" = 4 : i32} {
  tt.func public @tcgen05_mma_kernel() attributes {noinline = false} {
    %0 = ttg.local_alloc : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
    %1 = ttg.local_alloc : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
    %result = ttng.tmem_alloc : () -> !ttg.memdesc<128x128xf16, #tmem, #ttng.tensor_memory, mutable> loc(#loc)
    %true = arith.constant true loc(#loc)
    %true_0 = arith.constant true loc(#loc)
    %2 = ttng.tc_gen5_mma %0, %1, %result[], %true, %true_0 : !ttg.memdesc<128x128xf16, #shared, #smem, mutable>, !ttg.memdesc<128x128xf16, #shared, #smem, mutable>, !ttg.memdesc<128x128xf16, #tmem, #ttng.tensor_memory, mutable> loc(#loc)
    tt.return loc(#loc)
  } loc(#loc)
} loc(#loc)
#loc = loc(unknown)
""")
0 commit comments