
Commit d8774e3

[TritonNvidiaGPU] Tighten WGMMA verifiers (#7708)
B must be in shared memory. Check the operand encodings as well.
Parent: 98d23fe

5 files changed: +14 −8 lines

include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td

Lines changed: 2 additions & 2 deletions
@@ -87,7 +87,7 @@ def TTNG_WarpGroupDotOp : TTNG_Op<"warp_group_dot", [
 
   let arguments = (ins
     TTG_TensorOrMemDesc:$a,
-    TTG_TensorOrMemDesc:$b,
+    TTG_MemDescType:$b,
     TT_FpIntTensor:$c,
     Optional<I1>:$useC,
     DefaultValuedAttr<TT_InputPrecisionAttr, "::mlir::triton::InputPrecision::IEEE">:$inputPrecision,
@@ -99,7 +99,7 @@ def TTNG_WarpGroupDotOp : TTNG_Op<"warp_group_dot", [
 
   let assemblyFormat = [{
     $a`,` $b`,` $c (`,` $useC^)? attr-dict
-    `:` type($a) `*` type($b) `->` type($d)
+    `:` type($a) `*` qualified(type($b)) `->` type($d)
   }];
 
   let extraClassDeclaration = [{

lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp

Lines changed: 10 additions & 0 deletions
@@ -63,9 +63,17 @@ LogicalResult WarpGroupDotOp::verify() {
   auto nvmmaEnc = dyn_cast<NvidiaMmaEncodingAttr>(resTy.getEncoding());
   if (!nvmmaEnc || !nvmmaEnc.isHopper())
     return emitOpError("WGMMA result layout must be Hopper NVMMA");
+
+  if (!isa<NVMMASharedEncodingAttr, DotOperandEncodingAttr>(
+          getA().getType().getEncoding()))
+    return emitOpError("WGMMA A operand must have NVMMA shared or dot layout");
+  if (!isa<NVMMASharedEncodingAttr>(getB().getType().getEncoding()))
+    return emitOpError("WGMMA B operand must have NVMMA shared layout");
+
   auto numWarps = gpu::lookupNumWarps(getOperation());
   if (numWarps % 4)
     return emitOpError("WGMMA requires num_warps to be divisible by 4");
+
   auto retShapePerCTA = getShapePerCTA(resTy);
   int rank = retShapePerCTA.size();
   if (rank != 2)
@@ -74,12 +82,14 @@ LogicalResult WarpGroupDotOp::verify() {
     return emitOpError("WGMMA result M dimension must be divisible by 64");
   if (retShapePerCTA[1] % 8 != 0)
     return emitOpError("WGMMA result N dimension must be divisible by 8");
+
   auto aElemTy = getA().getType().getElementType();
   if (!(llvm::isa<Float8E5M2Type, Float8E4M3FNType>(aElemTy) ||
         aElemTy.isInteger(8) || aElemTy.isF16() || aElemTy.isBF16() ||
         aElemTy.isF32()))
     return emitOpError("WGMMA result element type must be F16, BF16, F32, "
                        "F8E5M2, F8E4M3FN, or integer type");
+
   if (getMaxNumImpreciseAcc() < 32 &&
       (llvm::isa<Float8E5M2Type, Float8E4M3FNType>(aElemTy)) &&
       resTy.getElementType().isF32()) {
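
Taken together, the new checks reject any warp_group_dot whose A operand does not carry an NVMMA shared or dot-operand encoding, or whose B operand does not carry an NVMMA shared encoding, on top of the pre-existing warp-count and result-shape checks. The sketch below restates that contract as a standalone predicate; it is illustrative plain C++ with a made-up Encoding enum, not the actual MLIR verifier code.

#include <cstdint>
#include <cstdio>

// Hypothetical stand-ins for the encoding attributes the verifier inspects.
enum class Encoding { NvmmaShared, DotOperand, Blocked };

// Illustrative restatement of the operand/shape constraints enforced by
// WarpGroupDotOp::verify() after this change.
bool wgmmaOperandsValid(Encoding aEnc, Encoding bEnc, int64_t M, int64_t N,
                        int numWarps) {
  // A may live in NVMMA shared memory or use a dot-operand (register) layout.
  if (aEnc != Encoding::NvmmaShared && aEnc != Encoding::DotOperand)
    return false;
  // B must live in NVMMA shared memory.
  if (bEnc != Encoding::NvmmaShared)
    return false;
  // Pre-existing structural checks: warp count and result tile shape.
  return numWarps % 4 == 0 && M % 64 == 0 && N % 8 == 0;
}

int main() {
  // A in a dot-operand (register) layout, B in shared memory: accepted.
  std::printf("%d\n", wgmmaOperandsValid(Encoding::DotOperand,
                                         Encoding::NvmmaShared, 128, 256, 4));
  // B in a blocked (register) layout: rejected by the tightened verifier.
  std::printf("%d\n", wgmmaOperandsValid(Encoding::DotOperand,
                                         Encoding::Blocked, 128, 256, 4));
}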

python/triton/experimental/gluon/language/nvidia/hopper/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@ def warpgroup_mma(a, b, acc, *, use_acc=True, precision=None, max_num_imprecise_
 
     Args:
         a (tensor or shared_memory_descriptor): Left hand side operand.
-        b (tensor or shared_memory_descriptor): Right hand side operand.
+        b (shared_memory_descriptor): Right hand side operand.
         acc (tensor): Accumulator tensor.
         use_acc (bool): Whether to use the initial value of the accumulator. Defaults to True.
         precision (str, optional): Dot input precision. Defaults to builder default.

third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSDataPartition.cpp

Lines changed: 1 addition & 1 deletion
@@ -279,7 +279,7 @@ static bool getBackwardSliceToPartition(Value v,
       if (!getBackwardSliceToPartition(operand, partitionScheme, currentDim))
         return false;
     } else if (auto dotOp = dyn_cast<nvidia_gpu::WarpGroupDotOp>(op)) {
-      if (!getBackwardSliceToPartition(currentDim == 0 ? dotOp.getA()
+      if (!getBackwardSliceToPartition(currentDim == 0 ? Value(dotOp.getA())
                                                         : dotOp.getB(),
                                        partitionScheme, currentDim))
         return false;
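
The added Value(...) wrapper compensates for the TableGen change above: once $b is constrained to a memdesc, getA() and getB() evidently no longer return the same typed wrapper class, and the two arms of ?: must share a common type, so one arm is converted explicitly to the common Value type and the other converts implicitly. A minimal standalone sketch of that C++ rule, using hypothetical wrapper classes rather than the real MLIR types:

#include <iostream>
#include <string>

// Hypothetical stand-ins: two distinct wrappers that both convert to a common
// type, loosely mirroring typed value wrappers that convert to a plain Value.
struct CommonValue { std::string name; };
struct WrapperA { operator CommonValue() const { return {"A"}; } };
struct WrapperB { operator CommonValue() const { return {"B"}; } };

int main(int argc, char **argv) {
  (void)argv;
  WrapperA a;
  WrapperB b;
  // auto v = (argc == 1) ? a : b;             // error: the arms have no common type
  auto v = (argc == 1) ? CommonValue(a) : b;   // OK: b converts to CommonValue
  std::cout << v.name << "\n";
  return 0;
}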

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp

Lines changed: 0 additions & 4 deletions
@@ -497,10 +497,6 @@ LogicalResult convertWGMMA(triton::nvidia_gpu::WarpGroupDotOp op,
                            ConversionPatternRewriter &rewriter, Value thread) {
   auto AEnc = op.getA().getType().getEncoding();
   auto BEnc = op.getB().getType().getEncoding();
-  assert(mlir::isa<NVMMASharedEncodingAttr>(AEnc) ||
-         mlir::isa<DotOperandEncodingAttr>(AEnc));
-  assert(mlir::isa<NVMMASharedEncodingAttr>(BEnc) &&
-         "Operand B should use Shared layout.");
   return convertDot(typeConverter, rewriter, op.getLoc(), op.getOperation(), //
                     op.getA(), op.getB(), op.getC(), op.getD(), op.getUseC(), //
                     adaptor.getA(), adaptor.getB(), adaptor.getC(), //
