@@ -447,6 +447,46 @@ def test_tcgen05_mma(fresh_knobs):
447447""" )
448448
449449
450+ @gluon .jit
451+ def tcgen05_mma_mbar_kernel (nvmma_layout : ttgl .constexpr , acc_layout : ttgl .constexpr ):
452+ a = ttgl .allocate_shared_memory (ttgl .float16 , [128 , 128 ], nvmma_layout )
453+ b = ttgl .allocate_shared_memory (ttgl .float16 , [128 , 128 ], nvmma_layout )
454+ bar = ttgl .allocate_shared_memory (ttgl .int64 , [1 ], mbarrier .MBarrierLayout ())
455+ acc = blackwell .allocate_tensor_memory (ttgl .float16 , [128 , 128 ], acc_layout )
456+ blackwell .tcgen05_mma (a , b , acc , mbarriers = [bar ])
457+
458+
459+ @pytest .mark .skipif (not is_blackwell (), reason = "Requires blackwell tensor core" )
460+ def test_tcgen05_mma_mbar (fresh_knobs ):
461+ knobs .compilation .disable_line_info = True
462+
463+ nvmma_layout = ttgl .NVMMASharedLayout (swizzle_byte_width = 128 , element_bitwidth = 16 , rank = 2 )
464+ acc_layout = TensorMemoryLayout ([128 , 128 ], unpacked = True )
465+
466+ h = tcgen05_mma_mbar_kernel .warmup (nvmma_layout , acc_layout , grid = (1 , ))
467+ expecttest .assert_expected_inline (
468+ anonymize_ir (h .asm ["source" ]), """\
469+ #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
470+ #shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
471+ #smem = #ttg.shared_memory
472+ #tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>
473+ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
474+ tt.func public @tcgen05_mma_mbar_kernel() attributes {noinline = false} {
475+ %0 = ttg.local_alloc : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
476+ %1 = ttg.local_alloc : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
477+ %2 = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
478+ %result = ttng.tmem_alloc : () -> !ttg.memdesc<128x128xf16, #tmem, #ttng.tensor_memory, mutable> loc(#loc)
479+ %true = arith.constant true loc(#loc)
480+ %true_0 = arith.constant true loc(#loc)
481+ %true_1 = arith.constant true loc(#loc)
482+ %3 = ttng.tc_gen5_mma %0, %1, %result[], %true, %true_0, %2[%true_1] : !ttg.memdesc<128x128xf16, #shared, #smem, mutable>, !ttg.memdesc<128x128xf16, #shared, #smem, mutable>, !ttg.memdesc<128x128xf16, #tmem, #ttng.tensor_memory, mutable>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
483+ tt.return loc(#loc)
484+ } loc(#loc)
485+ } loc(#loc)
486+ #loc = loc(unknown)
487+ """ )
488+
489+
450490@filecheck_test
451491@gluon .jit
452492def test_tcgen05_commit ():
0 commit comments