
Commit a1a6b70

SamGinzburg authored and makslevental committed
[mlir][dialect] Refactor DotLike trait into a DotOpInterface + Enable verification of scaled_dot (triton-lang#5796)
# Overview

The Triton MLIR dialect currently has a `DotLike` trait carried by the dot ops. The problem with this trait is that, as implemented, it prevents `scaled_dot` from being verified properly: the operand-dimension check is disabled entirely ("TODO: enable back with an interface to support scaled dot."). This PR refactors the `DotLike` trait into an interface that implements a `verifyDims` function checking whether the dims of the A and B operands are compatible (e.g., A of shape MxK1 and B of shape K2xN are compatible when K1 == K2, in the simple case).

# How

The initial implementation of `DotOpInterface` mirrors the prior `DotLike` trait, except that it adds the `verifyDims` function, which every dot op must implement; this function checks whether the dimensions of the A and B inputs match. In the future this interface can be extended to include more functionality.

# Testing

Since this enables the verifier for `scaled_dot`, the existing scaled-dot tests should cover the changes made in this PR; if that turns out to be wrong, additional tests will be added.

================================================================

# New contributor declaration

- [x] I am not making a trivial change, such as fixing a typo in a comment.
- [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how).
- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following.
  - [ ] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [x] This PR does not need a test because `I think this should be covered by existing tests`.
- Select one of the following.
  - [x] I have not added any `lit` tests.
  - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)
1 parent 1b167d1 commit a1a6b70
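
To make the dimension rule above concrete, here is a standalone sketch of the K-compatibility check that `verifyDims` performs in the simple 2-D case. This is plain C++, not Triton code, and the shapes are invented for illustration:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Standalone sketch of the rule described above: for A of shape MxK1 and
// B of shape K2xN, a dot op is well formed only when K1 == K2.
static bool kDimsCompatible(const std::vector<int64_t> &aShape,
                            const std::vector<int64_t> &bShape) {
  return aShape[aShape.size() - 1] == bShape[bShape.size() - 2];
}

int main() {
  assert(kDimsCompatible({64, 32}, {32, 128}));  // 64x32 . 32x128 -> ok
  assert(!kDimsCompatible({64, 32}, {16, 128})); // K1=32 vs K2=16: rejected
  return 0;
}
```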

File tree

17 files changed: +191 -71 lines changed


include/triton/Dialect/Triton/IR/OpInterfaces.h

Lines changed: 2 additions & 0 deletions
```diff
@@ -11,6 +11,8 @@ namespace impl {
 
 LogicalResult verifyTransposeOpInterface(Operation *op);
 
+LogicalResult verifyDotOpInterface(Operation *op);
+
 } // namespace impl
 
 } // namespace triton
```

include/triton/Dialect/Triton/IR/Traits.h

Lines changed: 0 additions & 47 deletions
```diff
@@ -58,53 +58,6 @@ class VerifyTensorLayoutsTrait
   }
 };
 
-// Verify if the op is a dot-like operation.
-// A dot-like operation should have three operands.
-// The first two operands should share a common dimension, and the result
-// should have the dimensions of the two operands that are not shared.
-// A dot-like operation can be either 2d or 3d.
-// In the 3d case, the first dimension of operands is the batch dimension.
-template <class ConcreteType>
-class DotLike : public TraitBase<ConcreteType, DotLike> {
-public:
-  static LogicalResult verifyTrait(Operation *op) {
-    if (op->getNumOperands() < 3)
-      return op->emitOpError("expected at least 3 operands");
-    auto aTy = cast<ShapedType>(op->getOperand(0).getType());
-    auto bTy = cast<ShapedType>(op->getOperand(1).getType());
-    auto cTy = cast<ShapedType>(op->getOperand(2).getType());
-    auto aShape = aTy.getShape();
-    auto bShape = bTy.getShape();
-    auto cShape = cTy.getShape();
-    // Check if all 3d or all 2d
-    if (aShape.size() != 2 && aShape.size() != 3)
-      return op->emitOpError("expected operands to be 2d or 3d");
-    if (aShape.size() != bShape.size() || aShape.size() != cShape.size())
-      return op->emitOpError("expected all operands to have the same rank");
-    // Check if the first two operands share a common dimension
-    // TODO: enable back with an interface to support scaled dot.
-    // if (aShape[aShape.size() - 1] != bShape[aShape.size() - 2])
-    //   return op->emitOpError("expected the last dimension of the first
-    //   operand "
-    //                          "to be equal to the second-to-last dimension of
-    //                          " "the second operand");
-    // Check the batch dimension
-    if (aShape.size() == 3 &&
-        (aShape[0] != cShape[0] || bShape[0] != cShape[0]))
-      return op->emitOpError("expected the first dimension of the first "
-                             "operand to be equal to the first dimension of "
-                             "the result");
-    // Check the output shape
-    if (cShape[cShape.size() - 2] != aShape[aShape.size() - 2] ||
-        cShape[cShape.size() - 1] != bShape[aShape.size() - 1])
-      return op->emitOpError(
-          "expected the output shape to be the concatenation of the last "
-          "dimension of the first operand and the last dimension of the "
-          "second ");
-    return success();
-  }
-};
-
 template <typename ConcreteType>
 class SameOperandsAndResultEncoding
     : public TraitBase<ConcreteType, SameOperandsAndResultEncoding> {
```
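
Why the trait is deleted rather than patched: a trait's `verifyTrait` is a single static function shared by every op that carries the trait, so it had no way to ask `tt.dot_scaled` a different K-dimension question than `tt.dot`; that is exactly why the check above sat commented out. An interface method instead dispatches through the concrete op. A minimal sketch of that dispatch, assuming the `DotOpInterface` generated from this PR's `.td` files is in scope:

```cpp
#include "mlir/IR/Operation.h"

// Minimal dispatch sketch (assumes the generated DotOpInterface is visible):
// the same call site resolves to DotOp::verifyDims, DotScaledOp::verifyDims,
// etc., depending on the concrete op behind the Operation*.
static bool dimsOk(mlir::Operation *op) {
  if (auto dotOp = mlir::dyn_cast<mlir::triton::DotOpInterface>(op))
    return dotOp.verifyDims(); // per-op behavior, unlike a static trait check
  return false;                // not a dot-style op at all
}
```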

include/triton/Dialect/Triton/IR/TritonInterfaces.td

Lines changed: 0 additions & 1 deletion
```diff
@@ -6,7 +6,6 @@ include "mlir/Interfaces/InferTypeOpInterface.td"
 
 def TensorSizeTrait : NativeOpTrait<"TensorSizeTrait">;
 def VerifyTensorLayoutsTrait : NativeOpTrait<"VerifyTensorLayoutsTrait">;
-def DotLike : NativeOpTrait<"DotLike">;
 def SameOperandsEncoding : NativeOpTrait<"SameOperandsEncoding">;
 def SameOperandsAndResultEncoding : NativeOpTrait<"SameOperandsAndResultEncoding">;
 def SameLoadStoreOperandsShape : NativeOpTrait<"SameLoadStoreOperandsShape">;
```

include/triton/Dialect/Triton/IR/TritonOpInterfaces.td

Lines changed: 21 additions & 1 deletion
```diff
@@ -29,7 +29,27 @@ def TransposeOpInterface : OpInterface<"TransposeOpInterface"> {
                   /*args=*/(ins)>
   ];
 
-  let verify = [{ return ::mlir::triton::impl::verifyTransposeOpInterface($_op); }];
+  let verify = [{ return ::mlir::triton::impl::verifyTransposeOpInterface($_op); }];
+}
+
+def DotOpInterface : OpInterface<"DotOpInterface"> {
+  let description = [{
+    This interface is implemented by operations that perform a dot product.
+  }];
+
+  let cppNamespace = "::mlir::triton";
+
+  let methods = [
+      InterfaceMethod<
+        /*desc=*/[{
+          Verifies the dimensions of the A and B DotOp operands.
+        }],
+        /*retType=*/"bool",
+        /*methodName=*/"verifyDims",
+        /*args=*/(ins)>
+  ];
+
+  let verify = [{ return ::mlir::triton::impl::verifyDotOpInterface($_op); }];
 }
 
 
```
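
Ops opt in to this interface via `DeclareOpInterfaceMethods<DotOpInterface>` (see the `.td` changes below), which obliges each declaring op to define the declared method in C++. A sketch of that obligation for a hypothetical `MyDotOp` (illustrative only, including its `getA`/`getB` accessors; the real implementations appear in `lib/Dialect/Triton/IR/Ops.cpp` later in this diff):

```cpp
// Hypothetical op: each implementer answers the one question the interface
// asks, namely whether the A/B contraction dimensions are compatible.
bool MyDotOp::verifyDims() {
  auto aShape = getA().getType().getShape(); // ... x M x K1
  auto bShape = getB().getType().getShape(); // ... x K2 x N
  return aShape.back() == bShape[bShape.size() - 2]; // K1 == K2
}
```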

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 2 additions & 2 deletions
```diff
@@ -631,7 +631,7 @@ def TT_GetNumProgramsOp : TT_Op<"get_num_programs", [Pure]> {
 //
 def TT_DotOp : TT_Op<"dot", [Pure,
                              DeclareOpInterfaceMethods<InferTypeOpInterface>,
-                             DotLike,
+                             DeclareOpInterfaceMethods<DotOpInterface>,
                              TypesMatchWith<"result's type matches accumulator's type",
                                             "d", "c", "$_self">]> {
     let summary = "dot";
@@ -671,7 +671,7 @@ def TT_DotOp : TT_Op<"dot", [Pure,
 //
 def TT_DotScaledOp : TT_Op<"dot_scaled", [Pure,
                                           AttrSizedOperandSegments,
-                                          DotLike,
+                                          DeclareOpInterfaceMethods<DotOpInterface>,
                                           TypesMatchWith<"result's type matches accumulator's type",
                                                          "d", "c", "$_self">]> {
     let summary = "dot_scaled";
```

include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td

Lines changed: 6 additions & 3 deletions
```diff
@@ -27,6 +27,7 @@ include "mlir/Dialect/Arith/IR/ArithBase.td"
 include "triton/Dialect/Triton/IR/TritonTypes.td"
 include "triton/Dialect/Triton/IR/TritonAttrDefs.td"
 include "triton/Dialect/Triton/IR/TritonInterfaces.td"
+include "triton/Dialect/Triton/IR/TritonOpInterfaces.td"
 include "triton/Dialect/TritonGPU/IR/TritonGPUTypes.td"
 include "triton/Dialect/TritonGPU/IR/TritonGPUTypeInterfaces.td"
 include "mlir/IR/OpBase.td"
@@ -71,7 +72,7 @@ def TTNG_ClusterWaitOp : TTNG_Op<"cluster_wait", []> {
 //
 def TTNG_WarpGroupDotOp : TTNG_Op<"warp_group_dot", [DeclareOpInterfaceMethods<InferTypeOpInterface>,
                                                      DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
-                                                     DotLike,
+                                                     DeclareOpInterfaceMethods<DotOpInterface>,
                                                      TypesMatchWith<"result's type matches accumulator's type",
                                                                     "d", "c", "$_self">]> {
   let summary = "warp group dot";
@@ -325,7 +326,7 @@ def TTNG_TMAStoreWaitOp : TTNG_Op<"async_tma_store_wait"> {
   let assemblyFormat = "attr-dict";
 }
 
-def TTNG_TCGen5MMAOp : TTNG_Op<"tc_gen5_mma", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, DotLike]> {
+def TTNG_TCGen5MMAOp : TTNG_Op<"tc_gen5_mma", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, DeclareOpInterfaceMethods<DotOpInterface>]> {
   let summary = "block level op mapping to tensorcore gen5 mma";
 
   let description = [{
@@ -343,11 +344,12 @@ def TTNG_TCGen5MMAOp : TTNG_Op<"tc_gen5_mma", [DeclareOpInterfaceMethods<MemoryE
                        I1:$pred,
                        Optional<TTG_MemDescType>:$barrier,
                        OptionalAttr<UnitAttr>:$two_ctas);
+
   // TODO: improve printing format.
   let assemblyFormat = "$a`,` $b`,` $d`,` $useD`,` $pred (`,` $barrier^)? attr-dict `:` functional-type(operands, results)";
 }
 
-def TTNG_TCGen5MMAScaledOp : TTNG_Op<"tc_gen5_mma_scaled", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, DotLike]> {
+def TTNG_TCGen5MMAScaledOp : TTNG_Op<"tc_gen5_mma_scaled", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, DeclareOpInterfaceMethods<DotOpInterface>]> {
   let summary = "block level op mapping to tensorcore gen5 mma";
 
   let description = [{
@@ -366,6 +368,7 @@ def TTNG_TCGen5MMAScaledOp : TTNG_Op<"tc_gen5_mma_scaled", [DeclareOpInterfaceMe
                        I1:$useD,
                        I1:$pred,
                        Optional<TTG_MemDescType>:$barrier);
+
   // TODO: improve printing format.
   let assemblyFormat = "$a `,` $b `,` $d `,` $a_scale `,` $b_scale `,` $useD`,` $pred `lhs` `=` $a_type `rhs` `=` $b_type (`,` $barrier^)? attr-dict `:` functional-type(operands, results)";
 }
```

lib/Dialect/Triton/IR/OpInterfaces.cpp

Lines changed: 44 additions & 0 deletions
```diff
@@ -29,6 +29,50 @@ LogicalResult verifyTransposeOpInterface(Operation *op) {
   return success();
 }
 
+// A DotOpInterface operation should have at least three operands.
+// The first two operands should share a common dimension, and the result
+// should have the dimensions of the two operands that are not shared.
+// A DotOpInterface operation can be either 2d or 3d.
+// In the 3d case, the first dimension of operands is the batch dimension.
+LogicalResult verifyDotOpInterface(Operation *op) {
+  DotOpInterface dotOp = cast<mlir::triton::DotOpInterface>(op);
+
+  if (dotOp->getNumOperands() < 3)
+    return dotOp->emitOpError("expected at least 3 operands");
+  auto aTy = cast<ShapedType>(dotOp->getOperand(0).getType());
+  auto bTy = cast<ShapedType>(dotOp->getOperand(1).getType());
+  auto cTy = cast<ShapedType>(dotOp->getOperand(2).getType());
+  auto aShape = aTy.getShape();
+  auto bShape = bTy.getShape();
+  auto cShape = cTy.getShape();
+  // Check if all 3d or all 2d
+  if (aShape.size() != 2 && aShape.size() != 3)
+    return dotOp->emitOpError("expected operands to be 2d or 3d");
+  if (aShape.size() != bShape.size() || aShape.size() != cShape.size())
+    return dotOp->emitOpError("expected all operands to have the same rank");
+
+  // Check for valid A, B input shapes for dot
+  if (!dotOp.verifyDims())
+    return dotOp->emitOpError(
+        "expected the last dimension of the first operand "
+        "to be equal to the second-to-last dimension of "
+        "the second operand");
+
+  // Check the batch dimension
+  if (aShape.size() == 3 && (aShape[0] != cShape[0] || bShape[0] != cShape[0]))
+    return dotOp->emitOpError("expected the first dimension of the first "
+                              "operand to be equal to the first dimension of "
+                              "the result");
+  // Check the output shape
+  if (cShape[cShape.size() - 2] != aShape[aShape.size() - 2] ||
+      cShape[cShape.size() - 1] != bShape[aShape.size() - 1])
+    return dotOp->emitOpError(
+        "expected the output shape to be the concatenation of the last "
+        "dimension of the first operand and the last dimension of the "
+        "second ");
+  return success();
+}
+
 } // namespace impl
 } // namespace triton
 } // namespace mlir
```
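
The division of labor above: the interface verifier owns the rank, batch, and output-shape checks, while the per-op `verifyDims` owns only the K-dimension rule. A standalone trace of the same checks over plain vectors (not Triton code; this uses the plain `tt.dot` K rule, whereas `tt.dot_scaled` swaps in its own, see Ops.cpp below):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

using Shape = std::vector<int64_t>;

// Standalone trace of verifyDotOpInterface's checks for A (BxMxK or MxK),
// B (BxKxN or KxN), and result C (BxMxN or MxN).
static bool dotShapesOk(const Shape &a, const Shape &b, const Shape &c) {
  if (a.size() != 2 && a.size() != 3)
    return false; // operands must be 2d or 3d
  if (a.size() != b.size() || a.size() != c.size())
    return false; // all operands share one rank
  if (a.back() != b[b.size() - 2])
    return false; // verifyDims stand-in: K dims must agree
  if (a.size() == 3 && (a[0] != c[0] || b[0] != c[0]))
    return false; // batch dims must agree
  return c[c.size() - 2] == a[a.size() - 2] && // result M matches A
         c.back() == b.back();                 // result N matches B
}

int main() {
  assert(dotShapesOk({4, 64, 32}, {4, 32, 128}, {4, 64, 128}));  // well formed
  assert(!dotShapesOk({4, 64, 32}, {2, 32, 128}, {4, 64, 128})); // batch mismatch
  assert(!dotShapesOk({64, 32}, {32, 128}, {64, 64}));           // wrong output N
  return 0;
}
```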

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 22 additions & 0 deletions
```diff
@@ -309,6 +309,28 @@ LogicalResult DotOp::verify() {
                                                   bEncoding);
 }
 
+bool DotOp::verifyDims() {
+  auto aShape = this->getA().getType().getShape();
+  auto bShape = this->getB().getType().getShape();
+
+  return aShape[aShape.size() - 1] == bShape[aShape.size() - 2];
+}
+
+//-- DotScaledOp --
+bool DotScaledOp::verifyDims() {
+  auto aShape = this->getLhs().getType().getShape();
+  auto bShape = this->getRhs().getType().getShape();
+
+  auto aKdim = aShape[aShape.size() - 1];
+  auto bKdim = bShape[aShape.size() - 2];
+  if (this->getLhsType() == ScaleDotElemType::E2M1)
+    aKdim *= 2;
+  if (this->getRhsType() == ScaleDotElemType::E2M1)
+    bKdim *= 2;
+
+  return aKdim == bKdim;
+}
+
 //-- MakeRangeOp --
 OpFoldResult MakeRangeOp::fold(FoldAdaptor adaptor) {
   // make_range(start, start + 1) -> constant(start)
```
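
The `E2M1` branches above reflect storage packing: E2M1 is a 4-bit (fp4) element type, so two elements share one byte of storage and a tensor's stored K extent is half its logical K. A worked arithmetic check under that assumption (values invented for illustration):

```cpp
#include <cassert>
#include <cstdint>

int main() {
  // lhs is E2M1 (fp4): 32 stored columns hold 64 logical elements.
  int64_t aKdim = 32;
  // rhs is a regular element type: stored K equals logical K.
  int64_t bKdim = 64;
  aKdim *= 2;             // the E2M1 adjustment from DotScaledOp::verifyDims
  assert(aKdim == bKdim); // 64 == 64: the op verifies despite 32 != 64 raw
  return 0;
}
```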

lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp

Lines changed: 1 addition & 1 deletion
```diff
@@ -120,7 +120,7 @@ warpsPerTileV3(DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps,
   // to facilitate use cases like flash attention, allowing reductions within
   // the same warp.
   if (llvm::find_if(slices, [](Operation *op) {
-        return op->hasTrait<OpTrait::DotLike>();
+        return isa<mlir::triton::DotOpInterface>(op);
      }) != slices.end())
    return {(unsigned)numWarps, 1};
```
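
This call-site change is the pattern the rest of the PR repeats: a trait query names a class template on the registered op, while an interface query goes through MLIR's `isa`/`dyn_cast` casting machinery. A side-by-side sketch (assumes the Triton interface header is in scope):

```cpp
#include "mlir/IR/Operation.h"

// Before: a trait tag check on the op class.
//   op->hasTrait<mlir::OpTrait::DotLike>()
// After: an interface query, which additionally grants access to verifyDims().
static bool isDotLike(mlir::Operation *op) {
  return mlir::isa<mlir::triton::DotOpInterface>(op);
}
```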

lib/Dialect/TritonGPU/Transforms/OptimizeAccumulatorInit.cpp

Lines changed: 12 additions & 7 deletions
```diff
@@ -55,7 +55,9 @@ class TMEMAllocWithUnusedInit
 };
 
 bool dotSupportsAccInitFlag(Operation *op) {
-  assert(op->hasTrait<OpTrait::DotLike>() && "Expected a dot-like operation");
+  assert(isa<DotOpInterface>(op) &&
+         "Expected an op which implements a DotOpInterface");
+
   if (auto wgDotOp = dyn_cast<triton::nvidia_gpu::WarpGroupDotOp>(op)) {
     // Partial accumulation would require a select op to handle the
     // initialization that would degrade the performance.
@@ -68,7 +70,9 @@ bool dotSupportsAccInitFlag(Operation *op) {
 }
 
 std::pair<Value, Operation *> getAccumulatorUseAndDef(Operation *op) {
-  assert(op->hasTrait<OpTrait::DotLike>() && "Expected a dot-like operation");
+  assert(isa<DotOpInterface>(op) &&
+         "Expected an op which implements a DotOpInterface");
+
   if (auto wgDotOp = dyn_cast<triton::nvidia_gpu::WarpGroupDotOp>(op)) {
     return std::make_pair(wgDotOp.getC(), wgDotOp);
   }
@@ -90,18 +94,20 @@ std::pair<Value, Operation *> getAccumulatorUseAndDef(Operation *op) {
       return std::make_pair(nullptr, nullptr);
     return std::make_pair(tmemAlloc.getSrc(), tmemLoad);
   }
-  assert(false && "Unexpected dot-like operation");
+  assert(false && "Unexpected op which implements a DotOpInterface");
   return std::make_pair(nullptr, nullptr);
 }
 
 void setUseAccFlag(Operation *op, Value useAcc) {
-  assert(op->hasTrait<OpTrait::DotLike>() && "Expected a dot-like operation");
+  assert(isa<DotOpInterface>(op) &&
+         "Expected an op which implements a DotOpInterface");
+
   if (auto wgDotOp = dyn_cast<triton::nvidia_gpu::WarpGroupDotOp>(op)) {
     wgDotOp.getUseCMutable().assign(useAcc);
   } else if (auto tc05MmaOp = dyn_cast<triton::nvidia_gpu::TCGen5MMAOp>(op)) {
     tc05MmaOp.getUseDMutable().assign(useAcc);
   } else {
-    assert(false && "Unexpected dot-like operation");
+    assert(false && "Unexpected op which implements a DotOpInterface");
   }
 }
 
@@ -159,9 +165,8 @@ class OptimizeAccumulatorInitPass
     ModuleOp m = getOperation();
     SmallVector<Operation *> mmaOps;
     m.walk([&](Operation *op) {
-      if (op->hasTrait<OpTrait::DotLike>() && dotSupportsAccInitFlag(op)) {
+      if (isa<DotOpInterface>(op) && dotSupportsAccInitFlag(op))
         mmaOps.push_back(op);
-      }
     });
 
     // for each mma op, find where the accumulator is initialized with zero
```
