Skip to content

Commit a15b8ca

Browse files
committed
[mlir] Implement memory-space cast operand fusion into consumers
This commit adds functionality to fuse memory-space casts into consumer operations, allowing operations to be performed directly on the original memory-space rather than first casting to a different memory space. Key changes: - Introduce `MemorySpaceCastOpInterface` to handle memory-space cast operations - Create a `FuseMemorySpaceCastsIntoConsumers` pass that identifies and fuses eligible casts - Add implementation for memref and vector operations to handle memory-space cast fusion - Add fuseCastOperands method to relevant operations to support the fusion In particular, in the current implementation only memory-space casts into the default memory-space can be fused. Example: ```mlir func.func @op_with_cast_sequence(%arg0: memref<4x4xf32, 1>, %arg1: index, %arg2: f32) -> memref<16xf32> { %memspacecast = memref.memory_space_cast %arg0 : memref<4x4xf32, 1> to memref<4x4xf32> %c0 = arith.constant 0 : index %c4 = arith.constant 4 : index %expanded = memref.expand_shape %memspacecast [[0], [1, 2]] output_shape [4, 2, 2] : memref<4x4xf32> into memref<4x2x2xf32> %collapsed = memref.collapse_shape %expanded [[0, 1, 2]] : memref<4x2x2xf32> into memref<16xf32> %loaded = memref.load %collapsed[%c0] : memref<16xf32> %added = arith.addf %loaded, %arg2 : f32 memref.store %added, %collapsed[%c0] : memref<16xf32> %atomic_result = memref.atomic_rmw addf %arg2, %collapsed[%c4] : (f32, memref<16xf32>) -> f32 return %collapsed : memref<16xf32> } // mlir-opt --fuse-memory-space-casts-into-consumers func.func @op_with_cast_sequence(%arg0: memref<4x4xf32, 1>, %arg1: index, %arg2: f32) -> memref<16xf32> { %c4 = arith.constant 4 : index %c0 = arith.constant 0 : index %expand_shape = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [4, 2, 2] : memref<4x4xf32, 1> into memref<4x2x2xf32, 1> %collapse_shape = memref.collapse_shape %expand_shape [[0, 1, 2]] : memref<4x2x2xf32, 1> into memref<16xf32, 1> %memspacecast = memref.memory_space_cast %collapse_shape : memref<16xf32, 1> to memref<16xf32> %0 = memref.load %collapse_shape[%c0] : memref<16xf32, 1> %1 = arith.addf %0, %arg2 : f32 memref.store %1, %collapse_shape[%c0] : memref<16xf32, 1> %2 = memref.atomic_rmw addf %arg2, %collapse_shape[%c4] : (f32, memref<16xf32, 1>) -> f32 return %memspacecast : memref<16xf32> } ``` Signed-off-by: Fabian Mora <[email protected]>
1 parent 221f8ee commit a15b8ca

File tree

18 files changed

+897
-7
lines changed

18 files changed

+897
-7
lines changed

mlir/include/mlir/Dialect/MemRef/IR/MemRef.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "mlir/Interfaces/ControlFlowInterfaces.h"
1919
#include "mlir/Interfaces/InferIntRangeInterface.h"
2020
#include "mlir/Interfaces/InferTypeOpInterface.h"
21+
#include "mlir/Interfaces/MemOpInterfaces.h"
2122
#include "mlir/Interfaces/MemorySlotInterfaces.h"
2223
#include "mlir/Interfaces/ShapedOpInterfaces.h"
2324
#include "mlir/Interfaces/SideEffectInterfaces.h"

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ include "mlir/Interfaces/CastInterfaces.td"
1515
include "mlir/Interfaces/ControlFlowInterfaces.td"
1616
include "mlir/Interfaces/InferIntRangeInterface.td"
1717
include "mlir/Interfaces/InferTypeOpInterface.td"
18+
include "mlir/Interfaces/MemOpInterfaces.td"
1819
include "mlir/Interfaces/MemorySlotInterfaces.td"
1920
include "mlir/Interfaces/ShapedOpInterfaces.td"
2021
include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -145,7 +146,8 @@ def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [
145146
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
146147
Pure,
147148
ViewLikeOpInterface,
148-
SameOperandsAndResultType
149+
SameOperandsAndResultType,
150+
DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>
149151
]> {
150152
let summary =
151153
"assumption that gives alignment information to the input memref";
@@ -456,6 +458,7 @@ def MemRef_AllocaScopeReturnOp : MemRef_Op<"alloca_scope.return",
456458
def MemRef_CastOp : MemRef_Op<"cast", [
457459
DeclareOpInterfaceMethods<CastOpInterface>,
458460
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
461+
DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
459462
MemRefsNormalizable,
460463
Pure,
461464
SameOperandsAndResultShape,
@@ -1194,6 +1197,7 @@ def LoadOp : MemRef_Op<"load",
11941197
"memref", "result",
11951198
"::llvm::cast<MemRefType>($_self).getElementType()">,
11961199
MemRefsNormalizable,
1200+
DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
11971201
DeclareOpInterfaceMethods<PromotableMemOpInterface>,
11981202
DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>]> {
11991203
let summary = "load operation";
@@ -1284,6 +1288,7 @@ def LoadOp : MemRef_Op<"load",
12841288
def MemRef_MemorySpaceCastOp : MemRef_Op<"memory_space_cast", [
12851289
DeclareOpInterfaceMethods<CastOpInterface>,
12861290
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
1291+
DeclareOpInterfaceMethods<MemorySpaceCastOpInterface>,
12871292
MemRefsNormalizable,
12881293
Pure,
12891294
SameOperandsAndResultElementType,
@@ -1376,6 +1381,7 @@ def MemRef_PrefetchOp : MemRef_Op<"prefetch"> {
13761381
def MemRef_ReinterpretCastOp
13771382
: MemRef_OpWithOffsetSizesAndStrides<"reinterpret_cast", [
13781383
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
1384+
DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
13791385
AttrSizedOperandSegments,
13801386
MemRefsNormalizable,
13811387
Pure,
@@ -1603,6 +1609,7 @@ def MemRef_RankOp : MemRef_Op<"rank", [Pure]> {
16031609

16041610
def MemRef_ReshapeOp: MemRef_Op<"reshape", [
16051611
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
1612+
DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
16061613
Pure,
16071614
ViewLikeOpInterface]> {
16081615
let summary = "memref reshape operation";
@@ -1701,6 +1708,7 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
17011708

17021709
def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
17031710
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
1711+
DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
17041712
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
17051713
let summary = "operation to produce a memref with a higher rank.";
17061714
let description = [{
@@ -1822,7 +1830,9 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
18221830
}
18231831

18241832
def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
1825-
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
1833+
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
1834+
DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>
1835+
]> {
18261836
let summary = "operation to produce a memref with a smaller rank.";
18271837
let description = [{
18281838
The `memref.collapse_shape` op produces a new view with a smaller rank
@@ -1929,6 +1939,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
19291939
"memref", "value",
19301940
"::llvm::cast<MemRefType>($_self).getElementType()">,
19311941
MemRefsNormalizable,
1942+
DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
19321943
DeclareOpInterfaceMethods<PromotableMemOpInterface>,
19331944
DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>]> {
19341945
let summary = "store operation";
@@ -2006,6 +2017,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
20062017

20072018
def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
20082019
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
2020+
DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
20092021
DeclareOpInterfaceMethods<ViewLikeOpInterface>,
20102022
AttrSizedOperandSegments,
20112023
OffsetSizeAndStrideOpInterface,
@@ -2281,6 +2293,7 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
22812293

22822294
def MemRef_TransposeOp : MemRef_Op<"transpose", [
22832295
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
2296+
DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
22842297
Pure]>,
22852298
Arguments<(ins AnyStridedMemRef:$in, AffineMapAttr:$permutation)>,
22862299
Results<(outs AnyStridedMemRef)> {
@@ -2316,6 +2329,7 @@ def MemRef_TransposeOp : MemRef_Op<"transpose", [
23162329

23172330
def MemRef_ViewOp : MemRef_Op<"view", [
23182331
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
2332+
DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
23192333
DeclareOpInterfaceMethods<ViewLikeOpInterface>,
23202334
Pure]> {
23212335
let summary = "memref view operation";
@@ -2392,6 +2406,7 @@ def MemRef_ViewOp : MemRef_Op<"view", [
23922406
//===----------------------------------------------------------------------===//
23932407

23942408
def AtomicRMWOp : MemRef_Op<"atomic_rmw", [
2409+
DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
23952410
AllTypesMatch<["value", "result"]>,
23962411
TypesMatchWith<"value type matches element type of memref",
23972412
"memref", "value",

mlir/include/mlir/Dialect/Vector/IR/VectorOps.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
2828
#include "mlir/Interfaces/IndexingMapOpInterface.h"
2929
#include "mlir/Interfaces/InferTypeOpInterface.h"
30+
#include "mlir/Interfaces/MemOpInterfaces.h"
3031
#include "mlir/Interfaces/SideEffectInterfaces.h"
3132
#include "mlir/Interfaces/VectorInterfaces.h"
3233
#include "mlir/Interfaces/ViewLikeInterface.h"

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ include "mlir/Interfaces/DestinationStyleOpInterface.td"
2424
include "mlir/Interfaces/IndexingMapOpInterface.td"
2525
include "mlir/Interfaces/InferIntRangeInterface.td"
2626
include "mlir/Interfaces/InferTypeOpInterface.td"
27+
include "mlir/Interfaces/MemOpInterfaces.td"
2728
include "mlir/Interfaces/SideEffectInterfaces.td"
2829
include "mlir/Interfaces/VectorInterfaces.td"
2930
include "mlir/Interfaces/ViewLikeInterface.td"
@@ -1246,6 +1247,7 @@ def Vector_TransferReadOp :
12461247
DeclareOpInterfaceMethods<MaskableOpInterface>,
12471248
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
12481249
DeclareOpInterfaceMethods<ConditionallySpeculatable>,
1250+
DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
12491251
AttrSizedOperandSegments,
12501252
DestinationStyleOpInterface
12511253
]>,
@@ -1493,6 +1495,7 @@ def Vector_TransferWriteOp :
14931495
DeclareOpInterfaceMethods<MaskableOpInterface>,
14941496
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
14951497
DeclareOpInterfaceMethods<ConditionallySpeculatable>,
1498+
DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
14961499
AttrSizedOperandSegments,
14971500
DestinationStyleOpInterface
14981501
]>,
@@ -1649,6 +1652,7 @@ def Vector_TransferWriteOp :
16491652

16501653
def Vector_LoadOp : Vector_Op<"load", [
16511654
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
1655+
DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>
16521656
]> {
16531657
let summary = "reads an n-D slice of memory into an n-D vector";
16541658
let description = [{
@@ -1765,6 +1769,7 @@ def Vector_LoadOp : Vector_Op<"load", [
17651769

17661770
def Vector_StoreOp : Vector_Op<"store", [
17671771
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
1772+
DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>
17681773
]> {
17691774
let summary = "writes an n-D vector to an n-D slice of memory";
17701775
let description = [{
@@ -1869,7 +1874,7 @@ def Vector_StoreOp : Vector_Op<"store", [
18691874
}
18701875

18711876
def Vector_MaskedLoadOp :
1872-
Vector_Op<"maskedload">,
1877+
Vector_Op<"maskedload", [DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>]>,
18731878
Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
18741879
Variadic<Index>:$indices,
18751880
VectorOfNonZeroRankOf<[I1]>:$mask,
@@ -1961,7 +1966,7 @@ def Vector_MaskedLoadOp :
19611966
}
19621967

19631968
def Vector_MaskedStoreOp :
1964-
Vector_Op<"maskedstore">,
1969+
Vector_Op<"maskedstore", [DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>]>,
19651970
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
19661971
Variadic<Index>:$indices,
19671972
VectorOfNonZeroRankOf<[I1]>:$mask,
@@ -2041,6 +2046,7 @@ def Vector_MaskedStoreOp :
20412046
def Vector_GatherOp :
20422047
Vector_Op<"gather", [
20432048
DeclareOpInterfaceMethods<MaskableOpInterface>,
2049+
DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
20442050
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
20452051
]>,
20462052
Arguments<(ins Arg<TensorOrMemRef<[AnyType]>, "", [MemRead]>:$base,
@@ -2144,7 +2150,7 @@ def Vector_GatherOp :
21442150
}
21452151

21462152
def Vector_ScatterOp :
2147-
Vector_Op<"scatter">,
2153+
Vector_Op<"scatter", [DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>]>,
21482154
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
21492155
Variadic<Index>:$offsets,
21502156
VectorOfNonZeroRankOf<[AnyInteger, Index]>:$indices,
@@ -2229,7 +2235,7 @@ def Vector_ScatterOp :
22292235
}
22302236

22312237
def Vector_ExpandLoadOp :
2232-
Vector_Op<"expandload">,
2238+
Vector_Op<"expandload", [DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>]>,
22332239
Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
22342240
Variadic<Index>:$indices,
22352241
FixedVectorOfNonZeroRankOf<[I1]>:$mask,
@@ -2317,7 +2323,7 @@ def Vector_ExpandLoadOp :
23172323
}
23182324

23192325
def Vector_CompressStoreOp :
2320-
Vector_Op<"compressstore">,
2326+
Vector_Op<"compressstore", [DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>]>,
23212327
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
23222328
Variadic<Index>:$indices,
23232329
FixedVectorOfNonZeroRankOf<[I1]>:$mask,

mlir/include/mlir/Interfaces/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ add_mlir_interface(IndexingMapOpInterface)
88
add_mlir_interface(InferIntRangeInterface)
99
add_mlir_interface(InferTypeOpInterface)
1010
add_mlir_interface(LoopLikeInterface)
11+
add_mlir_interface(MemOpInterfaces)
1112
add_mlir_interface(ParallelCombiningOpInterface)
1213
add_mlir_interface(RuntimeVerifiableOpInterface)
1314
add_mlir_interface(ShapedOpInterfaces)
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
//===- MemOpInterfaces.h - Memory operation interfaces ----------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file contains declarations of interfaces for operations that interact
10+
// with memory.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#ifndef MLIR_INTERFACES_MEMOPINTERFACES_H
15+
#define MLIR_INTERFACES_MEMOPINTERFACES_H
16+
17+
#include "mlir/IR/OpDefinition.h"
18+
19+
namespace mlir {
20+
namespace detail {
21+
/// Attempt to verify the given memory space cast operation.
22+
LogicalResult verifyMemorySpaceCastOpInterface(Operation *op);
23+
24+
/// Tries to fuse inplace a `MemorySpaceCastOpInterface` operation referenced by
25+
/// `operand`. On success, it returns `results`, and sets `modifiedInPlace` to
26+
/// true. It returns failure if `operand` doesn't reference a
27+
/// `MemorySpaceCastOpInterface` op.
28+
FailureOr<SmallVector<Value>>
29+
fuseInPlaceMemorySpaceCastImpl(OpOperand &operand, ValueRange results,
30+
bool &modifiedInPlace);
31+
} // namespace detail
32+
} // namespace mlir
33+
34+
/// Include the generated interface declarations.
35+
#include "mlir/Interfaces/MemOpInterfaces.h.inc"
36+
37+
#endif // MLIR_INTERFACES_MEMOPINTERFACES_H
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
//===- MemOpInterfaces.td - Memory operation interfaces -----*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file contains interfaces for operations that interact with memory.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef MLIR_INTERFACES_MEMOPINTERFACES_TD
14+
#define MLIR_INTERFACES_MEMOPINTERFACES_TD
15+
16+
include "mlir/IR/OpBase.td"
17+
include "mlir/Interfaces/SideEffectInterfaces.td"
18+
19+
def FuseMemorySpaceCastConsumerOpInterface :
20+
OpInterface<"FuseMemorySpaceCastConsumerOpInterface"> {
21+
let description = [{
22+
An interface to fuse memory-space cast operands into a consumer operation.
23+
It is the responsibility of the interface to determine which casts can be
24+
fused into the operation.
25+
}];
26+
let cppNamespace = "::mlir";
27+
let methods = [
28+
InterfaceMethod<[{
29+
Attempt to fuse the incoming cast-like operands. Returns `success`
30+
and any new results on fusion success, otherwise it returns failure.
31+
If new results are produced, these must be compatible with the original
32+
operation results.
33+
34+
The `modifiedInPlace` parameter indicates whether the operation was
35+
modified in place. If `false` and the fusion succeeded, then the
36+
interface guarantees it is valid to erase the original operation.
37+
If `true`, then the interface must guarantee no operations were created
38+
by the method, and that no further IR modification is necessary. It is
39+
considered an error if `modifiedInPlace` is true and the fusion failed.
40+
41+
Any implementations of this method must not erase/replace the original
42+
operation, instead it is the caller responsibility to erase or replace
43+
the op with the results provided by the method.
44+
45+
Finally, any implementations of this method have to guarantee that the
46+
IR remains valid at all times.
47+
}],
48+
"::llvm::FailureOr<::llvm::SmallVector<::mlir::Value>>", "fuseCastOperands",
49+
(ins "::mlir::OpBuilder &":$builder, "bool &":$modifiedInPlace)
50+
>,
51+
];
52+
}
53+
54+
def MemorySpaceCastOpInterface : OpInterface<"MemorySpaceCastOpInterface"> {
55+
let description = [{
56+
An interface for operations that perform memory-space casts. This
57+
interface assumes that the cast operation is `pure`.
58+
59+
These operations expect to have a well-defined ptr-like operand, and
60+
a well-defined target ptr-like result.
61+
}];
62+
let cppNamespace = "::mlir";
63+
let methods = [
64+
InterfaceMethod<[{
65+
Returns the source ptr-like value.
66+
}],
67+
"::mlir::TypedValue<::mlir::PtrLikeTypeInterface>", "getSourcePtr"
68+
>,
69+
InterfaceMethod<[{
70+
Returns the target ptr-like value.
71+
}],
72+
"::mlir::TypedValue<::mlir::PtrLikeTypeInterface>", "getTargetPtr"
73+
>,
74+
InterfaceMethod<[{
75+
Returns whether the memory space cast specified by `tgt` and `src`
76+
is supported.
77+
}],
78+
"bool", "isValidMemorySpaceCast",
79+
(ins "::mlir::PtrLikeTypeInterface":$tgt,
80+
"::mlir::PtrLikeTypeInterface":$src)
81+
>,
82+
InterfaceMethod<[{
83+
Clones the memory space cast op with the given source and target type.
84+
}],
85+
"::mlir::MemorySpaceCastOpInterface", "cloneMemorySpaceCastOp",
86+
(ins "::mlir::OpBuilder &":$builder, "::mlir::Type":$tgt,
87+
"::mlir::Value":$src)
88+
>,
89+
InterfaceMethod<[{
90+
Returns whether the cast allows to be fused.
91+
}],
92+
"bool", "isFusableMemorySpaceCast"
93+
>
94+
];
95+
let verify = [{
96+
return ::mlir::detail::verifyMemorySpaceCastOpInterface($_op);
97+
}];
98+
let dependentTraits = [Pure];
99+
let extraClassDeclaration = [{
100+
/// Returns the underlying `MemorySpaceCastOpInterface` op if `value`
101+
/// is produced by a `MemorySpaceCastOpInterface` op, and
102+
/// `isFusableMemorySpaceCast` returns true, otherwise it returns null.
103+
static ::mlir::MemorySpaceCastOpInterface
104+
getIfFusableCast(::mlir::Value value) {
105+
auto op = ::llvm::dyn_cast_or_null<::mlir::MemorySpaceCastOpInterface>(
106+
value.getDefiningOp());
107+
if (!op || !op.isFusableMemorySpaceCast())
108+
return nullptr;
109+
return op;
110+
}
111+
}];
112+
}
113+
114+
#endif // MLIR_INTERFACES_MEMOPINTERFACES_TD

0 commit comments

Comments
 (0)