@@ -15,6 +15,7 @@ include "mlir/Interfaces/CastInterfaces.td"
1515include "mlir/Interfaces/ControlFlowInterfaces.td"
1616include "mlir/Interfaces/InferIntRangeInterface.td"
1717include "mlir/Interfaces/InferTypeOpInterface.td"
18+ include "mlir/Interfaces/MemOpInterfaces.td"
1819include "mlir/Interfaces/MemorySlotInterfaces.td"
1920include "mlir/Interfaces/ShapedOpInterfaces.td"
2021include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -145,7 +146,8 @@ def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [
145146 DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
146147 Pure,
147148 ViewLikeOpInterface,
148- SameOperandsAndResultType
149+ SameOperandsAndResultType,
150+ DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>
149151 ]> {
150152 let summary =
151153 "assumption that gives alignment information to the input memref";
@@ -456,6 +458,7 @@ def MemRef_AllocaScopeReturnOp : MemRef_Op<"alloca_scope.return",
456458def MemRef_CastOp : MemRef_Op<"cast", [
457459 DeclareOpInterfaceMethods<CastOpInterface>,
458460 DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
461+ DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
459462 MemRefsNormalizable,
460463 Pure,
461464 SameOperandsAndResultShape,
@@ -1194,6 +1197,7 @@ def LoadOp : MemRef_Op<"load",
11941197 "memref", "result",
11951198 "::llvm::cast<MemRefType>($_self).getElementType()">,
11961199 MemRefsNormalizable,
1200+ DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
11971201 DeclareOpInterfaceMethods<PromotableMemOpInterface>,
11981202 DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>]> {
11991203 let summary = "load operation";
@@ -1284,6 +1288,7 @@ def LoadOp : MemRef_Op<"load",
12841288def MemRef_MemorySpaceCastOp : MemRef_Op<"memory_space_cast", [
12851289 DeclareOpInterfaceMethods<CastOpInterface>,
12861290 DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
1291+ DeclareOpInterfaceMethods<MemorySpaceCastOpInterface>,
12871292 MemRefsNormalizable,
12881293 Pure,
12891294 SameOperandsAndResultElementType,
@@ -1376,6 +1381,7 @@ def MemRef_PrefetchOp : MemRef_Op<"prefetch"> {
13761381def MemRef_ReinterpretCastOp
13771382 : MemRef_OpWithOffsetSizesAndStrides<"reinterpret_cast", [
13781383 DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
1384+ DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
13791385 AttrSizedOperandSegments,
13801386 MemRefsNormalizable,
13811387 Pure,
@@ -1603,6 +1609,7 @@ def MemRef_RankOp : MemRef_Op<"rank", [Pure]> {
16031609
16041610def MemRef_ReshapeOp: MemRef_Op<"reshape", [
16051611 DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
1612+ DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
16061613 Pure,
16071614 ViewLikeOpInterface]> {
16081615 let summary = "memref reshape operation";
@@ -1701,6 +1708,7 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
17011708
17021709def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
17031710 DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
1711+ DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
17041712 DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
17051713 let summary = "operation to produce a memref with a higher rank.";
17061714 let description = [{
@@ -1822,7 +1830,9 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
18221830}
18231831
18241832def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
1825- DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
1833+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
1834+ DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>
1835+ ]> {
18261836 let summary = "operation to produce a memref with a smaller rank.";
18271837 let description = [{
18281838 The `memref.collapse_shape` op produces a new view with a smaller rank
@@ -1929,6 +1939,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
19291939 "memref", "value",
19301940 "::llvm::cast<MemRefType>($_self).getElementType()">,
19311941 MemRefsNormalizable,
1942+ DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
19321943 DeclareOpInterfaceMethods<PromotableMemOpInterface>,
19331944 DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>]> {
19341945 let summary = "store operation";
@@ -2006,6 +2017,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
20062017
20072018def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
20082019 DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
2020+ DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
20092021 DeclareOpInterfaceMethods<ViewLikeOpInterface>,
20102022 AttrSizedOperandSegments,
20112023 OffsetSizeAndStrideOpInterface,
@@ -2281,6 +2293,7 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
22812293
22822294def MemRef_TransposeOp : MemRef_Op<"transpose", [
22832295 DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
2296+ DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
22842297 Pure]>,
22852298 Arguments<(ins AnyStridedMemRef:$in, AffineMapAttr:$permutation)>,
22862299 Results<(outs AnyStridedMemRef)> {
@@ -2316,6 +2329,7 @@ def MemRef_TransposeOp : MemRef_Op<"transpose", [
23162329
23172330def MemRef_ViewOp : MemRef_Op<"view", [
23182331 DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
2332+ DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
23192333 DeclareOpInterfaceMethods<ViewLikeOpInterface>,
23202334 Pure]> {
23212335 let summary = "memref view operation";
@@ -2392,6 +2406,7 @@ def MemRef_ViewOp : MemRef_Op<"view", [
23922406//===----------------------------------------------------------------------===//
23932407
23942408def AtomicRMWOp : MemRef_Op<"atomic_rmw", [
2409+ DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
23952410 AllTypesMatch<["value", "result"]>,
23962411 TypesMatchWith<"value type matches element type of memref",
23972412 "memref", "value",
0 commit comments