Skip to content

Commit 1186806

Browse files
lezcanoMogball
andauthored
[LAYOUTS] Create a trait that implements Layout equality by comparing the LLs (#5747)
As per title --------- Co-authored-by: Mogball <[email protected]>
1 parent b9eda84 commit 1186806

File tree

5 files changed

+38
-22
lines changed

5 files changed

+38
-22
lines changed

include/triton/Dialect/Triton/IR/Traits.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#include "mlir/IR/BuiltinTypes.h"
55
#include "mlir/IR/OpDefinition.h"
6+
#include "mlir/Interfaces/InferTypeOpInterface.h"
67
#include "mlir/Support/LogicalResult.h"
78
#include "triton/Dialect/Triton/IR/Types.h"
89

@@ -27,7 +28,7 @@ LogicalResult verifyTensorLayouts(Operation *op);
2728

2829
LogicalResult verifySameOperandsEncoding(Operation *op,
2930
bool allowTensorPointerType = false);
30-
31+
LogicalResult verifyEquivalentType(Type typeA, Type typeB);
3132
LogicalResult
3233
verifySameOperandsAndResultEncoding(Operation *op,
3334
bool allowTensorPointerType = false);

include/triton/Dialect/Triton/IR/TritonInterfaces.td

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#define TRITON_INTERFACES
33

44
include "mlir/IR/OpBase.td"
5+
include "mlir/Interfaces/InferTypeOpInterface.td"
56

67
def TensorSizeTrait : NativeOpTrait<"TensorSizeTrait">;
78
def VerifyTensorLayoutsTrait : NativeOpTrait<"VerifyTensorLayoutsTrait">;
@@ -13,4 +14,17 @@ def SameLoadStoreOperandsAndResultShape : NativeOpTrait<"SameLoadStoreOperandsAn
1314
def SameLoadStoreOperandsEncoding : NativeOpTrait<"SameLoadStoreOperandsEncoding">;
1415
def SameLoadStoreOperandsAndResultEncoding : NativeOpTrait<"SameLoadStoreOperandsAndResultEncoding">;
1516

17+
// A trait equivalent to InferTypeOpAdaptor, but that checks for structural
18+
// equivalence of the layouts of the result rather than just layout equality.
19+
def InferTypeOpWithLayoutEquivalence : InferTypeOpAdaptorBase<[{
20+
static bool isCompatibleReturnTypes(TypeRange lhs, TypeRange rhs) {
21+
if (lhs.size() != rhs.size())
22+
return false;
23+
return llvm::all_of(llvm::zip(lhs, rhs), [](auto tup) {
24+
auto [lhs, rhs] = tup;
25+
return succeeded(OpTrait::impl::verifyEquivalentType(lhs, rhs));
26+
});
27+
}
28+
}]>;
29+
1630
#endif // TRITON_INTERFACES

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -539,7 +539,7 @@ def TT_SplitOp : TT_Op<"split", [
539539

540540
def TT_TransOp : TT_Op<"trans", [Pure,
541541
TransposeOpInterface,
542-
InferTypeOpAdaptorWithIsCompatible,
542+
InferTypeOpWithLayoutEquivalence,
543543
SameOperandsAndResultElementType]> {
544544

545545
let summary = "rearrange the dimensions of a tensor";

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -235,26 +235,6 @@ LogicalResult TransOp::inferReturnTypes(
235235
return success();
236236
}
237237

238-
bool TransOp::isCompatibleReturnTypes(TypeRange lhs, TypeRange rhs) {
239-
assert(lhs.size() == rhs.size());
240-
assert(lhs.size() == 1);
241-
auto lhsType = cast<RankedTensorType>(lhs[0]);
242-
auto rhsType = cast<RankedTensorType>(rhs[0]);
243-
244-
if (lhsType.getShape() != rhsType.getShape())
245-
return false;
246-
247-
auto lhsEnc = lhsType.getEncoding();
248-
auto rhsEnc = rhsType.getEncoding();
249-
// If there's no encoding or the encodings are the same
250-
if (lhsEnc == rhsEnc)
251-
return true;
252-
253-
return cast<DialectInferLayoutInterface>(&lhsEnc.getDialect())
254-
->verifyLayoutsAreEqual(lhsType.getShape(), lhsEnc, rhsEnc, {})
255-
.succeeded();
256-
}
257-
258238
//-- DotOp --
259239
LogicalResult
260240
DotOp::inferReturnTypes(MLIRContext *context, std::optional<Location> location,

lib/Dialect/Triton/IR/Traits.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,33 @@
33
#include <numeric>
44

55
#include "mlir/IR/TypeUtilities.h"
6+
#include "triton/Dialect/Triton/IR/Dialect.h"
67
#include "triton/Dialect/Triton/IR/Types.h"
78
#include "triton/Dialect/Triton/IR/Utility.h"
89
#include "llvm/Support/ErrorHandling.h"
910

1011
using namespace mlir;
1112

13+
LogicalResult OpTrait::impl::verifyEquivalentType(Type typeA, Type typeB) {
14+
auto tensorTypeA = dyn_cast<RankedTensorType>(typeA);
15+
auto tensorTypeB = dyn_cast<RankedTensorType>(typeB);
16+
if (!(bool(tensorTypeA) && bool(tensorTypeB)))
17+
return typeA == typeB ? success() : failure();
18+
auto encodingA = tensorTypeA.getEncoding();
19+
auto encodingB = tensorTypeB.getEncoding();
20+
auto shapeA = tensorTypeA.getShape();
21+
auto shapeB = tensorTypeB.getShape();
22+
if (shapeA != shapeB)
23+
return failure();
24+
25+
// If there's no encoding or the encodings are the same
26+
if (encodingA == encodingB)
27+
return success();
28+
29+
return cast<triton::DialectInferLayoutInterface>(&encodingA.getDialect())
30+
->verifyLayoutsAreEqual(shapeA, encodingA, encodingB, {});
31+
}
32+
1233
static LogicalResult verifySameEncoding(Type typeA, Type typeB,
1334
bool allowTensorPointerType) {
1435
// TODO(Keren): the allowTensorPointerType argument is a hack to allow.

0 commit comments

Comments
 (0)