Commit 4e6c423

[BACKEND] Set 2CTA mode as a global flag (#8653)
We do so by looking at the flags of the `tcgen05.mma` dots and making sure they all agree. Once we support this mode in `dot_scaled`, we will check those as well.
1 parent a21bbbc commit 4e6c423

File tree: 10 files changed, +119 -26 lines changed

include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h

Lines changed: 12 additions & 0 deletions
@@ -26,6 +26,7 @@
 
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
@@ -51,6 +52,17 @@ LogicalResult verifyMMAv5Op(Operation *op);
 
 namespace mlir::triton::nvidia_gpu {
 
+constexpr static char AttrTwoCTAsName[] = "ttng.two-ctas";
+
+inline bool getModuleTwoCTAs(ModuleOp mod) {
+  auto attr = mod->getAttrOfType<BoolAttr>(AttrTwoCTAsName);
+  return attr ? attr.getValue() : false;
+}
+
+inline bool getModuleTwoCTAs(Operation *op) {
+  return getModuleTwoCTAs(op->getParentOfType<ModuleOp>());
+}
+
 struct TensorMemory : public SideEffects::Resource::Base<TensorMemory> {
   StringRef getName() final { return "<TensorMemory>"; }
 };
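
For context, here is a minimal sketch (not part of the commit) of how lowering code can consult this new module-level flag; the ctaGroupSuffix helper below is hypothetical, while getModuleTwoCTAs is the accessor introduced above and is used the same way in NVGPUToLLVMPass.cpp later in this diff.

// Sketch only: query the module-wide "ttng.two-ctas" flag from lowering code.
// ctaGroupSuffix is a hypothetical helper, not part of this commit.
#include <string>

#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"

namespace ttng = mlir::triton::nvidia_gpu;

// Builds the cta_group suffix of a tcgen05 PTX opcode from the module flag
// (getModuleTwoCTAs returns false when the attribute is absent).
static std::string ctaGroupSuffix(mlir::Operation *op) {
  bool twoCTAs = ttng::getModuleTwoCTAs(op);
  return ".cta_group::" + std::to_string(twoCTAs ? 2 : 1);
}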

include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td

Lines changed: 10 additions & 0 deletions
@@ -174,4 +174,14 @@ def TritonNvidiaGPURemoveTMEMTokensPass : Pass<"triton-nvidia-gpu-remove-tmem-to
   }];
 }
 
+def TritonNvidiaGPUCheckMatmulTwoCTAPass : Pass<"triton-nvidia-check-matmul-two-cta", "mlir::ModuleOp"> {
+  let summary = "Verify consistent two_ctas usage across matmuls";
+
+  let description = [{
+    Inspect all matmul operations and ensure they agree on the `two_ctas`
+    setting. Propagate the chosen value to the module so later lowering steps
+    can access it. Compilation fails if mixed configurations are detected.
+  }];
+}
+
 #endif

lib/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -1,4 +1,5 @@
 add_triton_library(TritonNvidiaGPUTransforms
+  CheckMatmulTwoCTAs.cpp
   FenceInsertion.cpp
   InterleaveTMem.cpp
   MMALowering.cpp
lib/Dialect/TritonNvidiaGPU/Transforms/CheckMatmulTwoCTAs.cpp (new file)

Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
+#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
+#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Visitors.h"
+
+namespace ttng = mlir::triton::nvidia_gpu;
+
+namespace mlir::triton::nvidia_gpu {
+
+#define GEN_PASS_DEF_TRITONNVIDIAGPUCHECKMATMULTWOCTAPASS
+#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc"
+
+namespace {
+
+class TritonNvidiaGPUCheckMatmulTwoCTAPass
+    : public impl::TritonNvidiaGPUCheckMatmulTwoCTAPassBase<
+          TritonNvidiaGPUCheckMatmulTwoCTAPass> {
+public:
+  using impl::TritonNvidiaGPUCheckMatmulTwoCTAPassBase<
+      TritonNvidiaGPUCheckMatmulTwoCTAPass>::
+      TritonNvidiaGPUCheckMatmulTwoCTAPassBase;
+
+  void runOnOperation() override {
+    ModuleOp mod = getOperation();
+    Operation *firstMatmul = nullptr;
+    bool firstTwoCTA = false;
+
+    WalkResult result = mod.walk([&](ttng::TCGen5MMAOp op) {
+      bool currentTwoCTA = op.getTwoCtas();
+      if (!firstMatmul) {
+        firstMatmul = op;
+        firstTwoCTA = currentTwoCTA;
+        return WalkResult::advance();
+      }
+      if (currentTwoCTA != firstTwoCTA) {
+        auto diag = op.emitError()
+                    << "inconsistent two_ctas setting across matmuls; "
+                       "expected all matmuls to "
+                    << (firstTwoCTA ? "enable" : "disable") << " two_ctas.";
+        diag.attachNote(firstMatmul->getLoc())
+            << "first matmul here has two_ctas="
+            << (firstTwoCTA ? "true" : "false") << ".";
+        return WalkResult::interrupt();
+      }
+      return WalkResult::advance();
+    });
+
+    if (result.wasInterrupted()) {
+      signalPassFailure();
+      return;
+    }
+
+    bool twoCTAValue = firstMatmul ? firstTwoCTA : false;
+    mod->setAttr(AttrTwoCTAsName, BoolAttr::get(mod.getContext(), twoCTAValue));
+  }
+};
+
+} // namespace
+
+} // namespace mlir::triton::nvidia_gpu
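
As a usage sketch (assumed, not part of the commit), the pass can be appended to a hand-built C++ pass pipeline right before lowering; the factory function name matches the one exported through the Python binding in triton_nvidia.cc below.

// Sketch only: wiring the new check into a C++ pass pipeline.
// addTwoCTACheck and the surrounding pipeline are hypothetical.
#include "mlir/Pass/PassManager.h"
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"

static void addTwoCTACheck(mlir::PassManager &pm) {
  // Fails compilation if tcgen05.mma ops disagree on two_ctas; otherwise
  // records the agreed value as the "ttng.two-ctas" module attribute.
  pm.addPass(
      mlir::triton::nvidia_gpu::createTritonNvidiaGPUCheckMatmulTwoCTAPass());
}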

test/Conversion/tritongpu_to_llvm_blackwell.mlir

Lines changed: 1 addition & 1 deletion
@@ -270,7 +270,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 #shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = false, elementBitWidth = 16, CTAsPerCGA = [1, 2], CTASplitNum = [1, 2], CTAOrder = [1, 0]}>
 #shared2 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [2], CTASplitNum = [1], CTAOrder = [0]}>
 #tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, colStride = 1, CTASplitM = 2>
-module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} {
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttng.two-ctas" = true} {
 // CHECK-LABEL: @tc_gen5_mma_2ctas
 tt.func @tc_gen5_mma_2ctas(%a: !ttg.memdesc<256x32xf16, #shared, #ttg.shared_memory>,
                            %b: !ttg.memdesc<32x128xf16, #shared1, #ttg.shared_memory>,

third_party/nvidia/backend/compiler.py

Lines changed: 1 addition & 0 deletions
@@ -348,6 +348,7 @@ def make_llir(self, src, metadata, options, capability):
         passes.gluon.add_inliner(pm)
         nvidia.passes.ttgpuir.add_allocate_shared_memory_nv(pm, capability, ptx_version)
         nvidia.passes.ttnvgpuir.add_allocate_tensor_memory(pm)
+        nvidia.passes.ttnvgpuir.add_check_matmul_two_cta(pm)
         if knobs.compilation.instrumentation_mode == "consan":
             # Call ConcurrencySanitizerPass here, before allocating global scratch memory but after allocating tensor and shared
             passes.ttgpuir.add_concurrency_sanitizer(pm)

third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp

Lines changed: 1 addition & 4 deletions
@@ -595,10 +595,7 @@ static Value initTensorMemory(LLVM::LLVMFuncOp func) {
     return LLVM::UndefOp::create(rewriter, loc, ptr_ty(ctx, 6));
   }
 
-  int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(mod);
-  // Assume that 2CTAs is used if we have two CTAs this is pessimistic but
-  // should be fine for now.
-  bool useTwoCTAs = numCTAs == 2;
+  bool useTwoCTAs = mlir::triton::nvidia_gpu::getModuleTwoCTAs(mod);
   // This code is only executed by the default warp group.
   Value threadId = NVVM::ThreadIdXOp::create(rewriter, loc, i32_ty);
   Value pred = b.icmp_ult(threadId, b.i32_val(32));

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp

Lines changed: 16 additions & 13 deletions
@@ -264,17 +264,17 @@ static void createScaledGen5MMA(ConversionPatternRewriter &rewriter,
                                 MemDescOperand a, Value b, MemDescOperand d,
                                 Value scaleA, Value scaleB, Value pred,
                                 Value instDescriptor, Value useInitAcc,
-                                bool aInTmem, mxfpKind mxfpInstKind) {
+                                bool aInTmem, mxfpKind mxfpInstKind,
+                                bool twoCTAs) {
   PTXBuilder ptxBuilder;
-  std::string opcode;
+  std::string opcode =
+      "tcgen05.mma.cta_group::" + std::to_string(twoCTAs ? 2 : 1) + ".kind::";
   if (mxfpInstKind == mxfpKind::mxf8f6f4) {
-    opcode =
-        "tcgen05.mma.cta_group::1.kind::mxf8f6f4.block_scale.scale_vec::1X";
+    opcode += "mxf8f6f4.block_scale.scale_vec::1X";
   } else if (mxfpInstKind == mxfpKind::mxf4) {
-    opcode = "tcgen05.mma.cta_group::1.kind::mxf4.block_scale.scale_vec::2X";
+    opcode += "mxf4.block_scale.scale_vec::2X";
   } else if (mxfpInstKind == mxfpKind::mxf4nvf4) {
-    opcode =
-        "tcgen05.mma.cta_group::1.kind::mxf4nvf4.block_scale.scale_vec::4X";
+    opcode += "mxf4nvf4.block_scale.scale_vec::4X";
   } else {
     assert(0 && "Unsupported mxfp kind.");
   }
@@ -312,7 +312,9 @@ static void createMMACommit(ConversionPatternRewriter &rewriter, Location loc,
              "tcgen05.commit.cta_group::2.mbarrier::arrive::one.shared::"
              "cluster.multicast::cluster.b64 [$1], $2;";
   } else {
-    opcode = "@$0 tcgen05.commit.cta_group::1.mbarrier::arrive::one.b64 [$1];";
+    opcode =
+        "@$0 tcgen05.commit.cta_group::" + std::to_string(twoCTAs ? 2 : 1) +
+        ".mbarrier::arrive::one.b64 [$1];";
   }
   auto &barrierOp = *ptxBuilder.create(opcode);
   barrierOp(ptxOperands, /*onlyAttachMLIRArgs=*/true);
@@ -486,7 +488,8 @@ void convertDot(const LLVMTypeConverter &typeConverter,
   MemDescType bTensorTy = op.getB().getType();
   MemDescType dTensorTy = op.getD().getType();
   auto dLayout = cast<ttng::TensorMemoryEncodingAttr>(dTensorTy.getEncoding());
-  bool twoCTAs = op.getTwoCtas();
+  bool twoCTAs = ttng::getModuleTwoCTAs(op);
+  assert(twoCTAs == op.getTwoCtas());
 
   DotConversion dot;
 
@@ -595,6 +598,7 @@ void convertScaledDot(const LLVMTypeConverter &typeConverter,
   Value baseD = tb.ptrtoint(i32_ty, adaptor.getD());
   Value baseScaleA = tb.ptrtoint(i32_ty, adaptor.getAScale());
   Value baseScaleB = tb.ptrtoint(i32_ty, adaptor.getBScale());
+  bool twoCTAs = ttng::getModuleTwoCTAs(op);
 
   int numRows = 128;
   int colSizeInBits = 32;
@@ -634,14 +638,13 @@
                                       subWordIdx, subWordIdx, mxfpInstKind);
     createScaledGen5MMA(rewriter, loc, op, a, b, accAddress, scaleA, scaleB,
                         pred, instDescriptor, useInitAcc, desc.aInTmem,
-                        mxfpInstKind);
+                        mxfpInstKind, twoCTAs);
   };
 
   convertDotImpl(typeConverter, rewriter, loc, op.getA(), op.getB(),
                  adaptor.getA(), adaptor.getB(), dTensorTy, adaptor.getUseD(),
                  adaptor.getPred(), adaptor.getBarriers(),
-                 adaptor.getBarrierPreds(), /*twoCTAs=*/false, opKindIsMXFP4,
-                 dot);
+                 adaptor.getBarrierPreds(), twoCTAs, opKindIsMXFP4, dot);
 }
 
 //===----------------------------------------------------------------------===//
@@ -699,7 +702,7 @@ struct TCGen5CommitOpConversion
     pred = b.and_(adaptor.getPred(), pred);
 
     createMMACommit(rewriter, op.getLoc(), smemObj.getBase(), pred,
-                    op.getTwoCtas());
+                    ttng::getModuleTwoCTAs(op));
     rewriter.eraseOp(op);
     return success();
   }

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorMemoryToLLVM.cpp

Lines changed: 12 additions & 8 deletions
@@ -524,18 +524,20 @@
 };
 
 static void createCommit(ConversionPatternRewriter &rewriter, Location loc,
-                         Value barrier, Value pred) {
+                         Value barrier, Value pred, bool twoCTAs) {
   PTXBuilder ptxBuilder;
   auto *barrierOperand = ptxBuilder.newAddrOperand(barrier, "r");
-  std::string opcode = "tcgen05.commit.cta_group::1.mbarrier::arrive::one.b64";
+  std::string opcode =
+      "tcgen05.commit.cta_group::" + std::to_string(twoCTAs ? 2 : 1) +
+      ".mbarrier::arrive::one.b64";
   auto &barrierOp = *ptxBuilder.create(opcode);
   barrierOp(barrierOperand).predicate(pred);
   ptxBuilder.launch(rewriter, loc, void_ty(rewriter.getContext()));
 }
 
 static void createTcgen05Cp(ConversionPatternRewriter &rewriter, Location loc,
                             Value tmem_address, Value src_desc, Value pred,
-                            TMemCopyAtom atom) {
+                            TMemCopyAtom atom, bool twoCTAs) {
   PTXBuilder ptxBuilder;
   auto dst = ptxBuilder.newAddrOperand(tmem_address, "r");
   auto src = ptxBuilder.newOperand(src_desc, "l");
@@ -547,9 +549,9 @@ static void createTcgen05Cp(ConversionPatternRewriter &rewriter, Location loc,
   } else if (atom.multicast == 3) {
     warp = ".warpx4";
   }
-  std::string opcode = "tcgen05.cp.cta_group::1" + warp + "." +
-                       std::to_string(atom.nRow) + "x" +
-                       std::to_string(atom.bCol) + "b";
+  std::string opcode =
+      "tcgen05.cp.cta_group::" + std::to_string(twoCTAs ? 2 : 1) + warp + "." +
+      std::to_string(atom.nRow) + "x" + std::to_string(atom.bCol) + "b";
   auto &op = *ptxBuilder.create(opcode);
   op({dst, src}).predicate(pred);
   ptxBuilder.launch(rewriter, loc, void_ty(rewriter.getContext()));
@@ -592,6 +594,7 @@ static void copySharedToTmem(ConversionPatternRewriter &rewriter, Location loc,
   auto loader = DotOpMmaSmemLoader::build(loc, rewriter, cvtWarp, bitwidth,
                                           smemBase, instrShape, 0, 5);
   assert(!loader.getDescriptor().transposed);
+  bool twoCTAs = getModuleTwoCTAs(op);
   // Check correct lbo/sbo along the multicast
   auto strideRow = cvt.getBasis(kRow, llvm::Log2_32(8), kOffset);
   if ((atom.multicast & 1) == 0) {
@@ -608,7 +611,7 @@
     auto tmemAddr =
         b.or_(b.ptrtoint(i32_ty, baseDst), b.i32_val(col * bitwidth / 32),
               /*disjoint=*/true);
-    createTcgen05Cp(rewriter, loc, tmemAddr, desc, pred, atom);
+    createTcgen05Cp(rewriter, loc, tmemAddr, desc, pred, atom, twoCTAs);
   }
 }
 
@@ -622,13 +625,14 @@
     assert(lookupNumCTAs(rewriter) == 1 && "NYI");
     Location loc = op->getLoc();
     Value pred = LLVM::NVIDIA::createElectPredicateWarp0(loc, rewriter);
+    bool twoCTAs = getModuleTwoCTAs(op);
     copySharedToTmem(rewriter, loc, typeConverter, op, adaptor.getSrc(),
                      adaptor.getDst(), pred);
 
     if (op.getBarrier()) {
       auto barrier = LLVM::getSharedMemoryObjectFromStruct(
           op.getLoc(), adaptor.getBarrier(), i64_ty, rewriter);
-      createCommit(rewriter, loc, barrier.getBase(), pred);
+      createCommit(rewriter, loc, barrier.getBase(), pred, twoCTAs);
     }
 
     rewriter.eraseOp(op);

third_party/nvidia/triton_nvidia.cc

Lines changed: 2 additions & 0 deletions
@@ -60,6 +60,8 @@ void init_triton_nvidia_passes_ttnvgpuir(py::module &&m) {
                      ttng::createTritonNvidiaGPUPromoteLHSToTMemPass);
   ADD_PASS_WRAPPER_0("add_remove_tmem_tokens",
                      ttng::createTritonNvidiaGPURemoveTMEMTokensPass);
+  ADD_PASS_WRAPPER_0("add_check_matmul_two_cta",
+                     ttng::createTritonNvidiaGPUCheckMatmulTwoCTAPass);
   ADD_PASS_WRAPPER_0("add_nvgpu_to_llvm",
                      mlir::triton::createConvertNVGPUToLLVM);
   ADD_PASS_WRAPPER_0("add_warp_specialize_to_llvm",
