@@ -15,6 +15,7 @@ include "mlir/Interfaces/CastInterfaces.td"
1515include "mlir/Interfaces/ControlFlowInterfaces.td"
1616include "mlir/Interfaces/InferIntRangeInterface.td"
1717include "mlir/Interfaces/InferTypeOpInterface.td"
18+ include "mlir/Interfaces/MemOpInterfaces.td"
1819include "mlir/Interfaces/MemorySlotInterfaces.td"
1920include "mlir/Interfaces/ShapedOpInterfaces.td"
2021include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -145,7 +146,8 @@ def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [
145146 DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
146147 Pure,
147148 ViewLikeOpInterface,
148- SameOperandsAndResultType
149+ SameOperandsAndResultType,
150+ DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>
149151 ]> {
150152 let summary =
151153 "assumption that gives alignment information to the input memref";
@@ -456,6 +458,7 @@ def MemRef_AllocaScopeReturnOp : MemRef_Op<"alloca_scope.return",
456458def MemRef_CastOp : MemRef_Op<"cast", [
457459 DeclareOpInterfaceMethods<CastOpInterface>,
458460 DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
461+ DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
459462 MemRefsNormalizable,
460463 Pure,
461464 SameOperandsAndResultShape,
@@ -1194,6 +1197,7 @@ def LoadOp : MemRef_Op<"load",
11941197 "memref", "result",
11951198 "::llvm::cast<MemRefType>($_self).getElementType()">,
11961199 MemRefsNormalizable,
1200+ DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
11971201 DeclareOpInterfaceMethods<PromotableMemOpInterface>,
11981202 DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>]> {
11991203 let summary = "load operation";
@@ -1284,6 +1288,7 @@ def LoadOp : MemRef_Op<"load",
12841288def MemRef_MemorySpaceCastOp : MemRef_Op<"memory_space_cast", [
12851289 DeclareOpInterfaceMethods<CastOpInterface>,
12861290 DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
1291+ MemorySpaceCastOpInterface,
12871292 MemRefsNormalizable,
12881293 Pure,
12891294 SameOperandsAndResultElementType,
@@ -1302,6 +1307,10 @@ def MemRef_MemorySpaceCastOp : MemRef_Op<"memory_space_cast", [
13021307
13031308 If the source and target address spaces are the same, this operation is a noop.
13041309
1310+ Finally, if the target memory-space is the generic/default memory-space,
1311+ then it is assumed this cast can be bubbled down safely. See the docs of
1312+ `MemorySpaceCastOpInterface` interface for more details.
1313+
13051314 Example:
13061315
13071316 ```mlir
@@ -1321,6 +1330,27 @@ def MemRef_MemorySpaceCastOp : MemRef_Op<"memory_space_cast", [
13211330
13221331 let extraClassDeclaration = [{
13231332 Value getViewSource() { return getSource(); }
1333+
1334+ //===------------------------------------------------------------------===//
1335+ // MemorySpaceCastConsumerOpInterface
1336+ //===------------------------------------------------------------------===//
1337+ /// Returns the `source` memref.
1338+ TypedValue<PtrLikeTypeInterface> getSourcePtr();
1339+ /// Returns the `dest` memref.
1340+ TypedValue<PtrLikeTypeInterface> getTargetPtr();
1341+ /// Returns whether the memory-space cast is valid. Only casts between
1342+ /// memrefs are considered valid. Further, the `tgt` and `src` should only
1343+ /// differ on the memory-space parameter of the memref type.
1344+ bool isValidMemorySpaceCast(PtrLikeTypeInterface tgt,
1345+ PtrLikeTypeInterface src);
1346+ /// Clones the operation using a new target type and source value.
1347+ MemorySpaceCastOpInterface cloneMemorySpaceCastOp(
1348+ OpBuilder &b, PtrLikeTypeInterface tgt,
1349+ TypedValue<PtrLikeTypeInterface> src);
1350+ /// Returns whether the `source` value can be promoted by the
1351+ /// `MemorySpaceCastConsumerOpInterface::bubbleDownCasts` method. The only
1352+ /// casts the op recognizes as promotable are to the generic memory-space.
1353+ bool isSourcePromotable();
13241354 }];
13251355
13261356 let hasFolder = 1;
@@ -1376,6 +1406,7 @@ def MemRef_PrefetchOp : MemRef_Op<"prefetch"> {
13761406def MemRef_ReinterpretCastOp
13771407 : MemRef_OpWithOffsetSizesAndStrides<"reinterpret_cast", [
13781408 DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
1409+ DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
13791410 AttrSizedOperandSegments,
13801411 MemRefsNormalizable,
13811412 Pure,
@@ -1603,6 +1634,7 @@ def MemRef_RankOp : MemRef_Op<"rank", [Pure]> {
16031634
16041635def MemRef_ReshapeOp: MemRef_Op<"reshape", [
16051636 DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
1637+ DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
16061638 Pure,
16071639 ViewLikeOpInterface]> {
16081640 let summary = "memref reshape operation";
@@ -1701,6 +1733,7 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
17011733
17021734def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
17031735 DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
1736+ DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
17041737 DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
17051738 let summary = "operation to produce a memref with a higher rank.";
17061739 let description = [{
@@ -1822,7 +1855,9 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
18221855}
18231856
18241857def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
1825- DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
1858+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
1859+ DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>
1860+ ]> {
18261861 let summary = "operation to produce a memref with a smaller rank.";
18271862 let description = [{
18281863 The `memref.collapse_shape` op produces a new view with a smaller rank
@@ -1929,6 +1964,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
19291964 "memref", "value",
19301965 "::llvm::cast<MemRefType>($_self).getElementType()">,
19311966 MemRefsNormalizable,
1967+ DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
19321968 DeclareOpInterfaceMethods<PromotableMemOpInterface>,
19331969 DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>]> {
19341970 let summary = "store operation";
@@ -2006,6 +2042,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
20062042
20072043def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
20082044 DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
2045+ DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
20092046 DeclareOpInterfaceMethods<ViewLikeOpInterface>,
20102047 AttrSizedOperandSegments,
20112048 OffsetSizeAndStrideOpInterface,
@@ -2281,6 +2318,7 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
22812318
22822319def MemRef_TransposeOp : MemRef_Op<"transpose", [
22832320 DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
2321+ DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
22842322 Pure]>,
22852323 Arguments<(ins AnyStridedMemRef:$in, AffineMapAttr:$permutation)>,
22862324 Results<(outs AnyStridedMemRef)> {
@@ -2316,6 +2354,7 @@ def MemRef_TransposeOp : MemRef_Op<"transpose", [
23162354
23172355def MemRef_ViewOp : MemRef_Op<"view", [
23182356 DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
2357+ DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
23192358 DeclareOpInterfaceMethods<ViewLikeOpInterface>,
23202359 Pure]> {
23212360 let summary = "memref view operation";
@@ -2392,6 +2431,7 @@ def MemRef_ViewOp : MemRef_Op<"view", [
23922431//===----------------------------------------------------------------------===//
23932432
23942433def AtomicRMWOp : MemRef_Op<"atomic_rmw", [
2434+ DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
23952435 AllTypesMatch<["value", "result"]>,
23962436 TypesMatchWith<"value type matches element type of memref",
23972437 "memref", "value",
0 commit comments