File tree Expand file tree Collapse file tree 5 files changed +38
-22
lines changed
include/triton/Dialect/Triton/IR Expand file tree Collapse file tree 5 files changed +38
-22
lines changed Original file line number Diff line number Diff line change 33
44#include " mlir/IR/BuiltinTypes.h"
55#include " mlir/IR/OpDefinition.h"
6+ #include " mlir/Interfaces/InferTypeOpInterface.h"
67#include " mlir/Support/LogicalResult.h"
78#include " triton/Dialect/Triton/IR/Types.h"
89
@@ -27,7 +28,7 @@ LogicalResult verifyTensorLayouts(Operation *op);
2728
2829LogicalResult verifySameOperandsEncoding (Operation *op,
2930 bool allowTensorPointerType = false );
30-
31+ LogicalResult verifyEquivalentType (Type typeA, Type typeB);
3132LogicalResult
3233verifySameOperandsAndResultEncoding (Operation *op,
3334 bool allowTensorPointerType = false );
Original file line number Diff line number Diff line change 22#define TRITON_INTERFACES
33
44include "mlir/IR/OpBase.td"
5+ include "mlir/Interfaces/InferTypeOpInterface.td"
56
67def TensorSizeTrait : NativeOpTrait<"TensorSizeTrait">;
78def VerifyTensorLayoutsTrait : NativeOpTrait<"VerifyTensorLayoutsTrait">;
@@ -13,4 +14,17 @@ def SameLoadStoreOperandsAndResultShape : NativeOpTrait<"SameLoadStoreOperandsAn
1314def SameLoadStoreOperandsEncoding : NativeOpTrait<"SameLoadStoreOperandsEncoding">;
1415def SameLoadStoreOperandsAndResultEncoding : NativeOpTrait<"SameLoadStoreOperandsAndResultEncoding">;
1516
17+ // A trait equivalent to InferTypeOpAdaptor, but that checks for structural
18+ // equivalence of the layouts of the result rather than just layout equality.
19+ def InferTypeOpWithLayoutEquivalence : InferTypeOpAdaptorBase<[{
20+ static bool isCompatibleReturnTypes(TypeRange lhs, TypeRange rhs) {
21+ if (lhs.size() != rhs.size())
22+ return false;
23+ return llvm::all_of(llvm::zip(lhs, rhs), [](auto tup) {
24+ auto [lhs, rhs] = tup;
25+ return succeeded(OpTrait::impl::verifyEquivalentType(lhs, rhs));
26+ });
27+ }
28+ }]>;
29+
1630#endif // TRITON_INTERFACES
Original file line number Diff line number Diff line change @@ -539,7 +539,7 @@ def TT_SplitOp : TT_Op<"split", [
539539
540540def TT_TransOp : TT_Op<"trans", [Pure,
541541 TransposeOpInterface,
542- InferTypeOpAdaptorWithIsCompatible ,
542+ InferTypeOpWithLayoutEquivalence ,
543543 SameOperandsAndResultElementType]> {
544544
545545 let summary = "rearrange the dimensions of a tensor";
Original file line number Diff line number Diff line change @@ -235,26 +235,6 @@ LogicalResult TransOp::inferReturnTypes(
235235 return success ();
236236}
237237
238- bool TransOp::isCompatibleReturnTypes (TypeRange lhs, TypeRange rhs) {
239- assert (lhs.size () == rhs.size ());
240- assert (lhs.size () == 1 );
241- auto lhsType = cast<RankedTensorType>(lhs[0 ]);
242- auto rhsType = cast<RankedTensorType>(rhs[0 ]);
243-
244- if (lhsType.getShape () != rhsType.getShape ())
245- return false ;
246-
247- auto lhsEnc = lhsType.getEncoding ();
248- auto rhsEnc = rhsType.getEncoding ();
249- // If there's no encoding or the encodings are the same
250- if (lhsEnc == rhsEnc)
251- return true ;
252-
253- return cast<DialectInferLayoutInterface>(&lhsEnc.getDialect ())
254- ->verifyLayoutsAreEqual (lhsType.getShape (), lhsEnc, rhsEnc, {})
255- .succeeded ();
256- }
257-
258238// -- DotOp --
259239LogicalResult
260240DotOp::inferReturnTypes (MLIRContext *context, std::optional<Location> location,
Original file line number Diff line number Diff line change 33#include < numeric>
44
55#include " mlir/IR/TypeUtilities.h"
6+ #include " triton/Dialect/Triton/IR/Dialect.h"
67#include " triton/Dialect/Triton/IR/Types.h"
78#include " triton/Dialect/Triton/IR/Utility.h"
89#include " llvm/Support/ErrorHandling.h"
910
1011using namespace mlir ;
1112
13+ LogicalResult OpTrait::impl::verifyEquivalentType (Type typeA, Type typeB) {
14+ auto tensorTypeA = dyn_cast<RankedTensorType>(typeA);
15+ auto tensorTypeB = dyn_cast<RankedTensorType>(typeB);
16+ if (!(bool (tensorTypeA) && bool (tensorTypeB)))
17+ return typeA == typeB ? success () : failure ();
18+ auto encodingA = tensorTypeA.getEncoding ();
19+ auto encodingB = tensorTypeB.getEncoding ();
20+ auto shapeA = tensorTypeA.getShape ();
21+ auto shapeB = tensorTypeB.getShape ();
22+ if (shapeA != shapeB)
23+ return failure ();
24+
25+ // If there's no encoding or the encodings are the same
26+ if (encodingA == encodingB)
27+ return success ();
28+
29+ return cast<triton::DialectInferLayoutInterface>(&encodingA.getDialect ())
30+ ->verifyLayoutsAreEqual (shapeA, encodingA, encodingB, {});
31+ }
32+
1233static LogicalResult verifySameEncoding (Type typeA, Type typeB,
1334 bool allowTensorPointerType) {
1435 // TODO(Keren): the allowTensorPointerType argument is a hack to allow.
You can’t perform that action at this time.
0 commit comments