 from triton import knobs
 from triton.experimental import gluon
 from triton.experimental.gluon import language as ttgl
+from triton.experimental.gluon.language.nvidia import blackwell
 from triton.experimental.gluon.language.nvidia.blackwell import mbarrier
 from triton._filecheck import filecheck_test
 import triton.language as tl
@@ -85,6 +86,7 @@ def tensor_memory_kernel(layout: ttgl.constexpr, tmem_layout: ttgl.constexpr):
     XBLOCK: ttgl.constexpr = tmem_layout.block[0]
     YBLOCK: ttgl.constexpr = tmem_layout.block[1]
     a = ttgl.full([XBLOCK, YBLOCK], 0, ttgl.int32, layout)
+    _ = ttgl.nvidia.blackwell.allocate_tensor_memory(ttgl.int32, a.shape, tmem_layout)
     mem = ttgl.nvidia.blackwell.allocate_tensor_memory(ttgl.int32, a.shape, tmem_layout, a)
     b = mem.load(layout)  # noqa: F841
     mem.store(a)
@@ -108,12 +110,13 @@ def test_tensor_memory(fresh_knobs):
   tt.func public @tensor_memory_kernel() attributes {noinline = false} {
     %c0_i32 = arith.constant 0 : i32 loc(#loc)
     %cst = arith.constant dense<0> : tensor<128x128xi32, #blocked> loc(#loc)
-    %result = ttng.tmem_alloc %cst : (tensor<128x128xi32, #blocked>) -> !ttg.memdesc<128x128xi32, #tmem, #ttng.tensor_memory, mutable> loc(#loc)
-    %result_0 = ttng.tmem_load %result : !ttg.memdesc<128x128xi32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xi32, #blocked> loc(#loc)
+    %result = ttng.tmem_alloc : () -> !ttg.memdesc<128x128xi32, #tmem, #ttng.tensor_memory, mutable> loc(#loc)
+    %result_0 = ttng.tmem_alloc %cst : (tensor<128x128xi32, #blocked>) -> !ttg.memdesc<128x128xi32, #tmem, #ttng.tensor_memory, mutable> loc(#loc)
+    %result_1 = ttng.tmem_load %result_0 : !ttg.memdesc<128x128xi32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xi32, #blocked> loc(#loc)
     %true = arith.constant true loc(#loc)
-    ttng.tmem_store %cst, %result, %true : tensor<128x128xi32, #blocked> -> !ttg.memdesc<128x128xi32, #tmem, #ttng.tensor_memory, mutable> loc(#loc)
-    %0 = ttng.tmem_subslice %result {N = 0 : i32} : !ttg.memdesc<128x128xi32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<128x64xi32, #tmem, #ttng.tensor_memory, mutable, 128x128> loc(#loc)
-    %1 = ttng.tmem_subslice %result {N = 64 : i32} : !ttg.memdesc<128x128xi32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<128x64xi32, #tmem, #ttng.tensor_memory, mutable, 128x128> loc(#loc)
+    ttng.tmem_store %cst, %result_0, %true : tensor<128x128xi32, #blocked> -> !ttg.memdesc<128x128xi32, #tmem, #ttng.tensor_memory, mutable> loc(#loc)
+    %0 = ttng.tmem_subslice %result_0 {N = 0 : i32} : !ttg.memdesc<128x128xi32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<128x64xi32, #tmem, #ttng.tensor_memory, mutable, 128x128> loc(#loc)
+    %1 = ttng.tmem_subslice %result_0 {N = 64 : i32} : !ttg.memdesc<128x128xi32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<128x64xi32, #tmem, #ttng.tensor_memory, mutable, 128x128> loc(#loc)
     tt.return loc(#loc)
   } loc(#loc)
 } loc(#loc)
@@ -217,3 +220,40 @@ def test_mbarrier(fresh_knobs):
 } loc(#loc)
 #loc = loc(unknown)
 """)
+
+
+@gluon.jit
+def tcgen05_mma_kernel(nvmma_layout: ttgl.constexpr, acc_layout: ttgl.constexpr):
+    a = ttgl.allocate_shared_memory(ttgl.float16, [128, 128], nvmma_layout)
+    b = ttgl.allocate_shared_memory(ttgl.float16, [128, 128], nvmma_layout)
+    acc = blackwell.allocate_tensor_memory(ttgl.float16, [128, 128], acc_layout)
+    blackwell.tcgen05_mma(a, b, acc)
+
+
+@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] != 10,
+                    reason="Requires blackwell tensor core")
+def test_tcgen05_mma(fresh_knobs):
+    knobs.compilation.disable_line_info = True
+
+    nvmma_layout = ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=16, rank=2)
+    acc_layout = blackwell.TensorMemoryLayout([128, 128], unpacked=True)
+
+    h = tcgen05_mma_kernel.warmup(nvmma_layout, acc_layout, grid=(1, ))
+    expecttest.assert_expected_inline(
+        h.asm["ttgir"], """\
+#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
+#smem = #ttg.shared_memory
+#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>
+module attributes {"ttg.num-warps" = 4 : i32} {
+  tt.func public @tcgen05_mma_kernel() attributes {noinline = false} {
+    %0 = ttg.local_alloc : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
+    %1 = ttg.local_alloc : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
+    %result = ttng.tmem_alloc : () -> !ttg.memdesc<128x128xf16, #tmem, #ttng.tensor_memory, mutable> loc(#loc)
+    %true = arith.constant true loc(#loc)
+    %true_0 = arith.constant true loc(#loc)
+    %2 = ttng.tc_gen5_mma %0, %1, %result[], %true, %true_0 : !ttg.memdesc<128x128xf16, #shared, #smem, mutable>, !ttg.memdesc<128x128xf16, #shared, #smem, mutable>, !ttg.memdesc<128x128xf16, #tmem, #ttng.tensor_memory, mutable> loc(#loc)
+    tt.return loc(#loc)
+  } loc(#loc)
+} loc(#loc)
+#loc = loc(unknown)
+""")