Commit 26d7722

[TritonNvidiaGPU] Tighten WGMMA verifier; improve FenceInsertion (#6801)
* verify that WarpGroupDotOp's result encoding is always NVMMA Hopper encoding
* clean up some code with this
* teach FenceInsertion to look through WarpSpecializeOp
* deduplicate fences (e.g. two dots in a loop with captured reg->shared operands)
1 parent 6ae57f9 commit 26d7722
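For illustration only (not part of the commit), a minimal MLIR sketch of the invariant the tightened verifier enforces: the result of ttng.warp_group_dot must carry a Hopper (version-major 3) #ttg.nvidia_mma encoding. The attribute parameters and tensor shapes below are invented for the example.

#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}>
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
#smem = #ttg.shared_memory

// Accepted: the result tensor uses the Hopper NVMMA layout.
%d = ttng.warp_group_dot %a, %b, %c : !ttg.memdesc<64x64xf16, #shared, #smem> * !ttg.memdesc<64x64xf16, #shared, #smem> -> tensor<64x64xf32, #mma>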

5 files changed: +132, -72 lines changed

include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td
Lines changed: 34 additions & 26 deletions

@@ -72,32 +72,40 @@ def TTNG_ClusterWaitOp : TTNG_Op<"cluster_wait", []> {
 //
 // WarpGroupDot Op
 //
-def TTNG_WarpGroupDotOp : TTNG_Op<"warp_group_dot", [DeclareOpInterfaceMethods<InferTypeOpInterface>,
-                                                     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
-                                                     DeclareOpInterfaceMethods<DotOpInterface>,
-                                                     TypesMatchWith<"result's type matches accumulator's type",
-                                                                    "d", "c", "$_self">]> {
-  let summary = "warp group dot";
-
-  let description = [{
-    $d = matrix_multiply($a, $b) + $c. For docs on InputPrecisionAttr, see TT_DotOp
-  }];
-
-  let arguments = (ins TTG_TensorOrMemDesc:$a,
-                       TTG_TensorOrMemDesc:$b,
-                       TT_FpIntTensor:$c,
-                       Optional<I1>:$useC,
-                       DefaultValuedAttr<TT_InputPrecisionAttr, "::mlir::triton::InputPrecision::IEEE">:$inputPrecision,
-                       DefaultValuedAttr<I32Attr, "0">:$maxNumImpreciseAcc,
-                       DefaultValuedAttr<BoolAttr, "false">:$isAsync);
-
-  let results = (outs TT_FpIntTensor:$d);
-
-  let assemblyFormat = "$a`,` $b`,` $c (`,` $useC^)? attr-dict `:` type($a) `*` type($b) `->` type($d)";
-
-  let extraClassDeclaration = [{
-    bool needsPartialAccumulator();
-  }];
+def TTNG_WarpGroupDotOp : TTNG_Op<"warp_group_dot", [
+  DeclareOpInterfaceMethods<InferTypeOpInterface>,
+  DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+  DeclareOpInterfaceMethods<DotOpInterface>,
+  TypesMatchWith<"result's type matches accumulator's type", "d", "c", "$_self">
+]> {
+  let summary = "warp group dot";
+
+  let description = [{
+    $d = matrix_multiply($a, $b) + $c. For docs on InputPrecisionAttr, see TT_DotOp
+  }];
+
+  let arguments = (ins
+    TTG_TensorOrMemDesc:$a,
+    TTG_TensorOrMemDesc:$b,
+    TT_FpIntTensor:$c,
+    Optional<I1>:$useC,
+    DefaultValuedAttr<TT_InputPrecisionAttr, "::mlir::triton::InputPrecision::IEEE">:$inputPrecision,
+    DefaultValuedAttr<I32Attr, "0">:$maxNumImpreciseAcc,
+    DefaultValuedAttr<BoolAttr, "false">:$isAsync
+  );
+
+  let results = (outs TT_FpIntTensor:$d);
+
+  let assemblyFormat = [{
+    $a`,` $b`,` $c (`,` $useC^)? attr-dict
+    `:` type($a) `*` type($b) `->` type($d)
+  }];
+
+  let extraClassDeclaration = [{
+    bool needsPartialAccumulator();
+  }];
+
+  let hasVerifier = 1;
 }
 
 def TTNG_WarpGroupDotWaitOp : TTNG_Op<"warp_group_dot_wait", [DeclareOpInterfaceMethods<InferTypeOpInterface>,

lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp
Lines changed: 14 additions & 8 deletions

@@ -34,7 +34,7 @@ namespace triton {
 namespace nvidia_gpu {
 
 // -- WarpGroupDotOp --
-mlir::LogicalResult WarpGroupDotOp::inferReturnTypes(
+LogicalResult WarpGroupDotOp::inferReturnTypes(
     MLIRContext *context, std::optional<Location> location, ValueRange operands,
     DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
     SmallVectorImpl<Type> &inferredReturnTypes) {
@@ -43,21 +43,27 @@ mlir::LogicalResult WarpGroupDotOp::inferReturnTypes(
   inferredReturnTypes.push_back(accTy);
 
   // verify encodings
-  auto aEnc =
-      cast<triton::gpu::TensorOrMemDesc>(operands[0].getType()).getEncoding();
-  auto bEnc =
-      cast<triton::gpu::TensorOrMemDesc>(operands[1].getType()).getEncoding();
+  auto aEnc = cast<TensorOrMemDesc>(operands[0].getType()).getEncoding();
+  auto bEnc = cast<TensorOrMemDesc>(operands[1].getType()).getEncoding();
   auto retEnc = accTy.getEncoding();
   if (aEnc) {
     assert(bEnc);
     Dialect &dialect = aEnc.getDialect();
     auto interface = cast<DialectInferLayoutInterface>(&dialect);
     if (interface->inferDotOpEncoding(aEnc, 0, retEnc, location).failed())
-      return mlir::failure();
+      return failure();
     if (interface->inferDotOpEncoding(bEnc, 1, retEnc, location).failed())
-      return mlir::failure();
+      return failure();
   }
-  return mlir::success();
+  return success();
+}
+
+LogicalResult WarpGroupDotOp::verify() {
+  auto nvmmaEnc =
+      dyn_cast<NvidiaMmaEncodingAttr>(getD().getType().getEncoding());
+  if (!nvmmaEnc || !nvmmaEnc.isHopper())
+    return emitOpError("WGMMA result layout must be Hopper NVMMA");
+  return success();
 }
 
 void WarpGroupDotOp::getEffects(
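A hedged illustration (layouts and shapes invented) of IR the new WarpGroupDotOp::verify() rejects: a result encoding that is not a Hopper NVMMA layout now produces the diagnostic added above at verification time, instead of surfacing later during lowering.

#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
// error: 'ttng.warp_group_dot' op WGMMA result layout must be Hopper NVMMA
%d = ttng.warp_group_dot %a, %b, %c : !ttg.memdesc<64x64xf16, #shared, #smem> * !ttg.memdesc<64x64xf16, #shared, #smem> -> tensor<64x64xf32, #blocked>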

lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp
Lines changed: 30 additions & 25 deletions

@@ -39,25 +39,17 @@ struct FenceInsertionPass
     if (computeCapability < 90)
       return;
     ModuleOp mod = getOperation();
-    mod.walk([&](Operation *op) {
-      bool isMMAv3 = isa<ttng::WarpGroupDotOp>(op);
-      if (!isMMAv3 && !isa<ttng::MMAv5OpInterface>(op))
-        return WalkResult::advance();
-      OpBuilder builder(op);
-      auto a = op->getOperand(0);
-      auto b = op->getOperand(1);
-      if (isMMAv3) {
-        auto mmaEncoding = dyn_cast<ttg::NvidiaMmaEncodingAttr>(
-            cast<RankedTensorType>(op->getResult(0).getType()).getEncoding());
-        if (!mmaEncoding || !mmaEncoding.isHopper())
-          return WalkResult::advance();
-      }
+    mod.walk([&](tt::DotOpInterface dotOp) {
+      Value a = dotOp.getA();
+      Value b = dotOp.getB();
       bool aDependsOnShared = dependOnCopyRegToShared(a);
       bool bDependsOnShared = dependOnCopyRegToShared(b);
       if (!aDependsOnShared && !bDependsOnShared)
         return WalkResult::advance();
-      Operation *fence = builder.create<ttng::FenceAsyncSharedOp>(
-          op->getLoc(), /*bCluster=*/false);
+
+      OpBuilder builder(dotOp);
+      auto fence = builder.create<ttng::FenceAsyncSharedOp>(dotOp.getLoc(),
+                                                            /*bCluster=*/false);
       // If there is all the dependencies are outside of the loop try to hoist
       // the fence.
       while (auto loopOp = fence->getParentOfType<LoopLikeOpInterface>()) {
@@ -69,6 +61,14 @@ struct FenceInsertionPass
           break;
         loopOp.moveOutOfLoop(fence);
       }
+
+      // If the previous op is already a fence, this one isn't needed.
+      if (auto lastFence = dyn_cast_or_null<ttng::FenceAsyncSharedOp>(
+              fence->getPrevNode())) {
+        if (lastFence.getBCluster() == fence.getBCluster())
+          fence.erase();
+      }
+
       return WalkResult::advance();
     });
   }
@@ -88,6 +88,7 @@ struct FenceInsertionPass
     visited.insert(operand);
     if (!isa<triton::gpu::MemDescType>(operand.getType()))
       return false;
+
     auto op = operand.getDefiningOp();
     if (op) {
       // reach an alloc copying from register, we need a fence.
@@ -100,26 +101,30 @@ struct FenceInsertionPass
       }
       return false;
     }
+
     // reach BlockArgument
     BlockArgument arg = cast<BlockArgument>(operand);
     unsigned argNum = arg.getArgNumber();
     Operation *argOwner = arg.getOwner()->getParentOp();
-    // support ForOp only
+    // look through ForOp iter argument
     if (auto forOp = dyn_cast<scf::ForOp>(argOwner)) {
+      assert(argNum != 0 && "induction var cannot be memdesc type");
+      --argNum;
      // prologue
-      auto iterOperands = forOp.getInitArgs();
-      if (argNum == 0)
-        return false;
-      if (dependOnCopyRegToShared(iterOperands[argNum - 1], visited))
+      if (dependOnCopyRegToShared(forOp.getInitArgs()[argNum], visited))
        return true;
      // yield
      auto yieldOp = forOp.getBody()->getTerminator();
-      Value v = yieldOp->getOperand(argNum - 1);
-      auto entry = std::make_pair<Operation *, unsigned>(std::move(yieldOp),
-                                                         std::move(argNum));
-      if (dependOnCopyRegToShared(v, visited))
-        return true;
+      Value v = yieldOp->getOperand(argNum);
+      return dependOnCopyRegToShared(v, visited);
     }
+
+    // look through `ttg.warp_specialize`.
+    if (auto wsOp = dyn_cast<ttg::WarpSpecializePartitionsOp>(argOwner)) {
+      return dependOnCopyRegToShared(
+          wsOp.getParentOp().getExplicitCaptures()[argNum]);
+    }
+
     // Conservatively return true for other ops
     return true;
   }
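A hypothetical sketch (all SSA names and shapes invented; layout aliases as in the sketch near the top) of the deduplication case from the commit message, two dots in a loop whose operands are register-to-shared copies: the fence inserted for each dot is hoisted to the same position above the loop, and the second, now-adjacent duplicate is erased because its previous op is already an equivalent fence.

%A = ttg.local_alloc %a_reg : (tensor<64x64xf16, #blocked>) -> !ttg.memdesc<64x64xf16, #shared, #smem>
%B = ttg.local_alloc %b_reg : (tensor<64x64xf16, #blocked>) -> !ttg.memdesc<64x64xf16, #shared, #smem>
// After FenceInsertion, a single hoisted fence guards both dots:
// ttng.fence_async_shared {bCluster = false}
scf.for %i = %c0_i32 to %c32_i32 step %c1_i32 : i32 {
  %d0 = ttng.warp_group_dot %A, %B, %acc0 : !ttg.memdesc<64x64xf16, #shared, #smem> * !ttg.memdesc<64x64xf16, #shared, #smem> -> tensor<64x64xf32, #mma>
  %d1 = ttng.warp_group_dot %A, %B, %acc1 : !ttg.memdesc<64x64xf16, #shared, #smem> * !ttg.memdesc<64x64xf16, #shared, #smem> -> tensor<64x64xf32, #mma>
}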

test/TritonGPU/fence-inserstion.mlir
Lines changed: 48 additions & 0 deletions

@@ -91,3 +91,51 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-
     tt.return
   }
 }
+
+// -----
+
+#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
+#smem = #ttg.shared_memory
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
+#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 64, unpacked = true>
+
+module attributes {ttg.target = "cuda:100", "ttg.num-warps" = 4 : i32} {
+
+// CHECK-LABEL: @mma_inside_warp_specialize
+tt.func @mma_inside_warp_specialize(%src: tensor<64x64xf16, #blocked>) {
+  %A = ttg.local_alloc %src : (tensor<64x64xf16, #blocked>) -> !ttg.memdesc<64x64xf16, #shared, #smem>
+  %B = ttg.local_alloc : () -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable>
+  %D = ttng.tmem_alloc : () -> !ttg.memdesc<64x64xf32, #tmem, #ttng.tensor_memory, mutable>
+
+  ttg.warp_specialize(%A, %B, %D)
+  default {
+    ttg.warp_yield
+  }
+  // CHECK: partition0
+  partition0(%lhs: !ttg.memdesc<64x64xf16, #shared, #smem>, %rhs: !ttg.memdesc<64x64xf16, #shared, #smem, mutable>, %acc: !ttg.memdesc<64x64xf32, #tmem, #ttng.tensor_memory, mutable>) num_warps(4) {
+    %true = arith.constant true
+    %c0_i32 = arith.constant 0 : i32
+    %c1_i32 = arith.constant 1 : i32
+    %c32_i32 = arith.constant 32 : i32
+    // CHECK: ttng.fence_async_shared
+    // CHECK-NEXT: scf.for
+    scf.for %i = %c0_i32 to %c32_i32 step %c1_i32 : i32 {
+      // CHECK-NEXT: ttng.tc_gen5_mma
+      ttng.tc_gen5_mma %lhs, %rhs, %acc, %true, %true : !ttg.memdesc<64x64xf16, #shared, #smem>, !ttg.memdesc<64x64xf16, #shared, #smem, mutable>, !ttg.memdesc<64x64xf32, #tmem, #ttng.tensor_memory, mutable>
+      // CHECK-NEXT: ttng.tc_gen5_mma
+      ttng.tc_gen5_mma %lhs, %rhs, %acc, %true, %true : !ttg.memdesc<64x64xf16, #shared, #smem>, !ttg.memdesc<64x64xf16, #shared, #smem, mutable>, !ttg.memdesc<64x64xf32, #tmem, #ttng.tensor_memory, mutable>
+    }
+    ttg.warp_return
+  }
+  // CHECK: partition1
+  partition1(%lhs: !ttg.memdesc<64x64xf16, #shared, #smem>, %rhs: !ttg.memdesc<64x64xf16, #shared, #smem, mutable>, %acc: !ttg.memdesc<64x64xf32, #tmem, #ttng.tensor_memory, mutable>) num_warps(4) {
+    // CHECK-NOT: ttng.fence_async_shared
+    %true = arith.constant true
+    // CHECK: ttng.tc_gen5_mma
+    ttng.tc_gen5_mma %rhs, %rhs, %acc, %true, %true : !ttg.memdesc<64x64xf16, #shared, #smem, mutable>, !ttg.memdesc<64x64xf16, #shared, #smem, mutable>, !ttg.memdesc<64x64xf32, #tmem, #ttng.tensor_memory, mutable>
+    ttg.warp_return
+  } : (!ttg.memdesc<64x64xf16, #shared, #smem>, !ttg.memdesc<64x64xf16, #shared, #smem, mutable>, !ttg.memdesc<64x64xf32, #tmem, #ttng.tensor_memory, mutable>) -> ()
+  tt.return
+}
+
+}

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM.cpp
Lines changed: 6 additions & 13 deletions

@@ -71,28 +71,21 @@ struct WarpGroupDotOpConversion
     auto loc = op.getLoc();
     // D = A * B + C
     Value A = op.getA();
-    Value D = op.getResult();
+    TypedValue<RankedTensorType> D = op.getResult();
 
     // Here we assume the DotOp's operands always comes from shared memory.
     auto AShapePerCTA = getShapePerCTA(A.getType());
     size_t reduceAxis = 1;
     unsigned K = AShapePerCTA[reduceAxis];
     bool isOuter = K == 1;
 
-    NvidiaMmaEncodingAttr mmaLayout = dyn_cast<NvidiaMmaEncodingAttr>(
-        cast<RankedTensorType>(D.getType()).getEncoding());
-    if (!isOuter && mmaLayout &&
-        supportMMA(op.getOperand(0), mmaLayout.getVersionMajor())) {
-      if (mmaLayout.isHopper()) {
-        return convertWGMMA(op, adaptor, getTypeConverter(), rewriter,
-                            getThreadId(rewriter, loc));
-      }
-
-      llvm::report_fatal_error(
-          "Unsupported MMA kind found when converting WarpGroupDotOp to LLVM.");
+    auto mmaLayout = cast<NvidiaMmaEncodingAttr>(D.getType().getEncoding());
+    if (!isOuter && supportMMA(op.getOperand(0), mmaLayout.getVersionMajor())) {
+      return convertWGMMA(op, adaptor, getTypeConverter(), rewriter,
+                          getThreadId(rewriter, loc));
     }
 
-    llvm::report_fatal_error(
+    return op.emitError(
         "Unsupported WarpGroupDotOp found when converting TritonGPU to LLVM.");
   }
 };
