
Commit fa73f39

[Gluon][TritonNvidiaGPU] Add and expose tcgen05.commit (#7335)
This adds a separate op that maps to `tcgen05.commit`. When writing persistent kernels, it is far simpler and more performant to enqueue the arrival on an mbarrier in the persistent-loop epilogue with a separate commit op than to figure out which specific MMA needs to arrive on the barrier. E.g.

```python
for tile_id in range(...):
    mma  # peeled MMA
    for _ in range(num_iters):
        mma
        mma
    for _ in range(num_masked_iters):
        mma
        mma
    commit
```

This will also be important when doing warp specialization on nested control flow (cc @masahi @mbrookhart).

This PR also slightly optimizes the codegen for selecting warp 0 when there is only one warp by calling `getWarpId`. This removes a few instructions in the MMA loop of a warp-specialized kernel.
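To make that pattern concrete at the Gluon level, here is a minimal sketch (not code from this PR: the kernel shape, loop bounds, import paths, and the `mbarrier.init`/`mbarrier.wait` helpers are assumptions for illustration). The body issues MMAs without naming a barrier on each one, and a single commit in the epilogue makes the barrier arrive once they have all completed:

```python
from triton.experimental import gluon
from triton.experimental.gluon import language as ttgl
from triton.experimental.gluon.language.nvidia import blackwell
from triton.experimental.gluon.language.nvidia.blackwell import mbarrier


@gluon.jit
def persistent_mma_sketch(a_smem, b_smem, acc_tmem,
                          num_tiles: ttgl.constexpr, num_iters: ttgl.constexpr):
    # Hypothetical kernel: one mbarrier tracks all MMAs issued for a tile.
    bar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout())
    mbarrier.init(bar, count=1)  # assumed helper
    phase = 0
    for _ in range(num_tiles):
        for _ in range(num_iters):
            # Enqueue async MMAs; no per-MMA barrier bookkeeping needed.
            blackwell.tcgen05_mma(a_smem, b_smem, acc_tmem)
        # Epilogue: one commit arms `bar` for everything issued above.
        blackwell.tcgen05_commit(bar)
        mbarrier.wait(bar, phase)  # assumed helper
        phase = 1 - phase
```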
1 parent 5a6a9a7 commit fa73f39

File tree

8 files changed: +113 −9 lines changed


include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td

Lines changed: 46 additions & 0 deletions
@@ -523,6 +523,52 @@ def TTNG_TCGen5MMAScaledOp : TTNG_Op<"tc_gen5_mma_scaled", [
   }];
 }
 
+def TTNG_TCGen5CommitOp : TTNG_Op<"tc_gen5_commit"> {
+  let summary = "make an mbarrier track completion of all prior async tcgen5 ops";
+
+  let description = [{
+    The `ttng.tc_gen5_commit` is an asynchronous operation that makes the
+    mbarrier object track the completion of all prior asynchronous tcgen5
+    operations. Upon completion of all asynchronous operations, the mbarrier
+    arrive operation is performed on the mbarrier with a count of 1.
+
+    If `two_ctas` is set, then the mbarrier tracks all prior operations
+    initiated with `two_ctas` set as well. Otherwise, it tracks all prior
+    operations initiated without `two_ctas`.
+
+    Note that the completion mechanisms are guaranteed to occur sequentially in
+    the order the commit operations were issued. This means, for example:
+
+    ```mlir
+    ttng.tmem_copy
+    ttng.tc_gen5_mma
+    ttng.tc_gen5_commit %barrierA
+    ttng.tc_gen5_commit %barrierB
+    ```
+
+    `%barrierA` tracks the completion of the previous TMEM copy and MMA
+    operations, but since the commit groups are sequential, the arrive-on
+    operation on `%barrierA` is guaranteed to be performed before the arrive-on
+    operation on `%barrierB`, even though its commit group is empty.
+  }];
+
+  let arguments = (ins
+    Arg<TTG_MemDescType, "", [MemWrite<SharedMemory>]>:$barrier,
+    Optional<I1>:$pred,
+    UnitAttr:$two_ctas
+  );
+
+  let assemblyFormat = [{
+    $barrier (`,` $pred^)? attr-dict `:` qualified(type($barrier))
+  }];
+
+  let builders = [
+    OpBuilder<(ins "Value":$barrier, CArg<"bool", "false">:$two_ctas), [{
+      build($_builder, $_state, barrier, /*pred=*/Value(), two_ctas);
+    }]>,
+  ];
+}
+
 def TTNG_TMEMLoadOp : TTNG_Op<"tmem_load"> {
   let summary = "Load a buffer from tensor memory into a distributed tensor";

python/src/gluon_ir.cc

Lines changed: 4 additions & 0 deletions
@@ -448,6 +448,10 @@ void init_gluon_ir(py::module &&m) {
                    pred, two_ctas, mbarriers,
                    mbarrier_preds);
       })
+      .def("create_tcgen05_commit",
+           [](GluonOpBuilder &self, Value &barrier) {
+             self.create<ttng::TCGen5CommitOp>(barrier);
+           })
 
       .def("create_async_tma_copy_global_to_local",
            [](GluonOpBuilder &self, Value descPtr, std::vector<Value> &coord,

python/test/gluon/test_frontend.py

Lines changed: 10 additions & 0 deletions
@@ -447,6 +447,16 @@ def test_tcgen05_mma(fresh_knobs):
     """)
 
 
+@filecheck_test
+@gluon.jit
+def test_tcgen05_commit():
+    # CHECK-LABEL: test_tcgen05_commit
+    barrier = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout())
+    # CHECK: [[BARRIER:%.*]] = ttg.local_alloc
+    # CHECK: ttng.tc_gen5_commit [[BARRIER]]
+    blackwell.tcgen05_commit(barrier)
+
+
 @gluon.jit
 def warpgroup_mma_kernel(nvmma_layout: ttgl.constexpr, acc_layout: ttgl.constexpr):
     a = ttgl.allocate_shared_memory(ttgl.float16, [128, 128], nvmma_layout)

python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -286,3 +286,8 @@ def tcgen05_mma(a, b, acc, *, use_acc=True, pred=True, mbarriers=None, mbarrier_
 
     _semantic.builder.create_tcgen05_mma(a.handle, b.handle, acc.handle, use_acc.handle, pred.handle, mbarriers,
                                          mbarrier_preds)
+
+
+@builtin
+def tcgen05_commit(barrier, _semantic=None):
+    _semantic.builder.create_tcgen05_commit(barrier.handle)
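As a usage note (not part of the diff): the new builtin is an alternative to attaching barriers directly to `tcgen05_mma` via its `mbarriers=` argument. A rough comparison, assuming `bar` is an mbarrier allocated as in the frontend test above and that `mbarriers=` accepts a list of barriers:

```python
# Existing style: the barrier is attached to one specific MMA.
blackwell.tcgen05_mma(a, b, acc, mbarriers=[bar])

# With this PR: issue MMAs barrier-free, then commit once afterwards;
# `bar` arrives when all previously issued tcgen5 ops have completed.
blackwell.tcgen05_mma(a, b, acc)
blackwell.tcgen05_commit(bar)
```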

test/Conversion/tritongpu_to_llvm_blackwell.mlir

Lines changed: 19 additions & 0 deletions
@@ -595,6 +595,25 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} {
 
 // -----
 
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
+#smem = #ttg.shared_memory
+
+module attributes {"ttg.num-warps" = 1 : i32} {
+// CHECK-LABEL: @tc_gen5_commit
+tt.func @tc_gen5_commit(%arg0: !ttg.memdesc<1xi64, #shared, #smem, mutable>, %pred: i1) {
+  // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32)
+  // CHECK: [[IS_WARP_0:%.*]] = llvm.icmp "eq" [[ZERO]], [[ZERO]]
+  // CHECK: [[ELECT:%.*]] = nvvm.elect.sync
+  // CHECK: [[WARP_PRED:%.*]] = llvm.and [[IS_WARP_0]], [[ELECT]]
+  // CHECK: [[PRED:%.*]] = llvm.and %arg1, [[WARP_PRED]]
+  // CHECK: @$0 tcgen05.commit.cta_group::1.mbarrier::arrive::one.b64 [$1];", "b,l" [[PRED]]
+  ttng.tc_gen5_commit %arg0, %pred : !ttg.memdesc<1xi64, #shared, #smem, mutable>
+  tt.return
+}
+}
+
+// -----
+
 #tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>
 
 module attributes {"ttg.num-warps" = 4 : i32} {

third_party/nvidia/backend/compiler.py

Lines changed: 1 addition & 1 deletion
@@ -294,6 +294,7 @@ def make_ttgir(mod, metadata, opt, capability):
         if capability // 10 >= 9:
             nvidia.passes.ttnvgpuir.add_tma_lowering(pm)
             nvidia.passes.ttnvgpuir.add_fence_insertion(pm, capability)
+        nvidia.passes.ttnvgpuir.add_lower_mma(pm)
         passes.common.add_sccp(pm)
         passes.common.add_canonicalizer(pm)
         pm.run(mod)
@@ -325,7 +326,6 @@ def make_llir(self, src, metadata, options, capability):
         pm = ir.pass_manager(mod.context)
         pm.enable_debug()
 
-        nvidia.passes.ttnvgpuir.add_lower_mma(pm)
         passes.ttgpuir.add_combine_tensor_select_and_if(pm)
         passes.ttgpuir.add_allocate_warp_groups(pm)
         passes.convert.add_scf_to_cf(pm)

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp

Lines changed: 26 additions & 6 deletions
@@ -678,8 +678,6 @@ struct TCGen5MMAOpConversion
            "Operand A should use Shared or Tensor memory layout.");
     assert(isa<NVMMASharedEncodingAttr>(BEnc) &&
            "Operand B should use Shared layout.");
-    assert(!op.getBarriers().empty() &&
-           "tensorcore op should have a barrier at this point.");
     convertDot(*getTypeConverter(), rewriter, op.getLoc(), op, adaptor);
     rewriter.eraseOp(op);
     return success();
@@ -693,14 +691,36 @@ struct TCGen5MMAScaledOpConversion
   LogicalResult
   matchAndRewrite(ttng::TCGen5MMAScaledOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    assert(!op.getBarriers().empty() &&
-           "tensorcore op should have a barrier at this point.");
     convertScaledDot(*getTypeConverter(), rewriter, op.getLoc(), op, adaptor);
     rewriter.eraseOp(op);
     return success();
   }
 };
 
+struct TCGen5CommitOpConversion
+    : public ConvertOpToLLVMPattern<ttng::TCGen5CommitOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(ttng::TCGen5CommitOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    TritonLLVMOpBuilder b(loc, rewriter);
+
+    auto smemObj = LLVM::getSharedMemoryObjectFromStruct(
+        loc, adaptor.getBarrier(), rewriter.getI64Type(), rewriter);
+    Value pred = LLVM::NVIDIA::createElectPredicateWarp0(loc, rewriter);
+
+    if (adaptor.getPred())
+      pred = b.and_(adaptor.getPred(), pred);
+
+    createMMACommit(rewriter, op.getLoc(), smemObj.getBase(), pred,
+                    op.getTwoCtas());
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 } // namespace
 
 namespace mlir {
@@ -710,8 +730,8 @@ namespace NVIDIA {
 void populateTCGen5MMAOpToLLVMPattern(LLVMTypeConverter &typeConverter,
                                       RewritePatternSet &patterns,
                                       PatternBenefit benefit) {
-  patterns.add<TCGen5MMAOpConversion, TCGen5MMAScaledOpConversion>(
-      typeConverter, benefit);
+  patterns.add<TCGen5MMAOpConversion, TCGen5MMAScaledOpConversion,
+               TCGen5CommitOpConversion>(typeConverter, benefit);
 }
 
 } // namespace NVIDIA

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp

Lines changed: 2 additions & 2 deletions
@@ -127,8 +127,8 @@ void createSyncWarp(Location loc, OpBuilder &rewriter) {
 
 Value createElectPredicateWarp0(Location loc, RewriterBase &rewriter) {
   auto b = TritonLLVMOpBuilder(loc, rewriter);
-  Value threadId = getThreadId(rewriter, loc);
-  Value warp0 = b.icmp_ult(threadId, b.i32_val(32));
+  Value warpId = getLaneAndWarpId(rewriter, loc).second;
+  Value warp0 = b.icmp_eq(warpId, b.i32_val(0));
   return b.and_(warp0, createElectPredicate(loc, rewriter));
 }

0 commit comments
