Skip to content

Commit bfaab80

Browse files
[GLUON] Fix and test for an issue with default mbar predicates for MMAv5 (#7432)
When mbarriers are passed to `blackwell.tcgen05_mma` without explicit predicates, the default Python bool predicates must be converted to IR values.
1 parent 673ca35 commit bfaab80

File tree

2 files changed

+41
-1
lines changed

2 files changed

+41
-1
lines changed

python/test/gluon/test_frontend.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,46 @@ def test_tcgen05_mma(fresh_knobs):
447447
""")
448448

449449

450+
@gluon.jit
451+
def tcgen05_mma_mbar_kernel(nvmma_layout: ttgl.constexpr, acc_layout: ttgl.constexpr):
452+
a = ttgl.allocate_shared_memory(ttgl.float16, [128, 128], nvmma_layout)
453+
b = ttgl.allocate_shared_memory(ttgl.float16, [128, 128], nvmma_layout)
454+
bar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout())
455+
acc = blackwell.allocate_tensor_memory(ttgl.float16, [128, 128], acc_layout)
456+
blackwell.tcgen05_mma(a, b, acc, mbarriers=[bar])
457+
458+
459+
@pytest.mark.skipif(not is_blackwell(), reason="Requires blackwell tensor core")
460+
def test_tcgen05_mma_mbar(fresh_knobs):
461+
knobs.compilation.disable_line_info = True
462+
463+
nvmma_layout = ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=16, rank=2)
464+
acc_layout = TensorMemoryLayout([128, 128], unpacked=True)
465+
466+
h = tcgen05_mma_mbar_kernel.warmup(nvmma_layout, acc_layout, grid=(1, ))
467+
expecttest.assert_expected_inline(
468+
anonymize_ir(h.asm["source"]), """\
469+
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
470+
#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
471+
#smem = #ttg.shared_memory
472+
#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>
473+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
474+
tt.func public @tcgen05_mma_mbar_kernel() attributes {noinline = false} {
475+
%0 = ttg.local_alloc : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
476+
%1 = ttg.local_alloc : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
477+
%2 = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
478+
%result = ttng.tmem_alloc : () -> !ttg.memdesc<128x128xf16, #tmem, #ttng.tensor_memory, mutable> loc(#loc)
479+
%true = arith.constant true loc(#loc)
480+
%true_0 = arith.constant true loc(#loc)
481+
%true_1 = arith.constant true loc(#loc)
482+
%3 = ttng.tc_gen5_mma %0, %1, %result[], %true, %true_0, %2[%true_1] : !ttg.memdesc<128x128xf16, #shared, #smem, mutable>, !ttg.memdesc<128x128xf16, #shared, #smem, mutable>, !ttg.memdesc<128x128xf16, #tmem, #ttng.tensor_memory, mutable>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
483+
tt.return loc(#loc)
484+
} loc(#loc)
485+
} loc(#loc)
486+
#loc = loc(unknown)
487+
""")
488+
489+
450490
@filecheck_test
451491
@gluon.jit
452492
def test_tcgen05_commit():

python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ def tcgen05_mma(a, b, acc, *, use_acc=True, pred=True, mbarriers=None, mbarrier_
280280
mbarriers = [bar.handle for bar in mbarriers]
281281
if mbarrier_preds is None:
282282
true = _semantic.to_tensor(True)
283-
mbarrier_preds = [true] * len(mbarriers)
283+
mbarrier_preds = [true.handle] * len(mbarriers)
284284
else:
285285
mbarrier_preds = _semantic._convert_to_ir_values(mbarrier_preds, require_i64=False)
286286

0 commit comments

Comments
 (0)