Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/MemOpInterfaces.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "mlir/Interfaces/ShapedOpInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
Expand Down
19 changes: 17 additions & 2 deletions mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/InferIntRangeInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/MemOpInterfaces.td"
include "mlir/Interfaces/MemorySlotInterfaces.td"
include "mlir/Interfaces/ShapedOpInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
Expand Down Expand Up @@ -145,7 +146,8 @@ def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
Pure,
ViewLikeOpInterface,
SameOperandsAndResultType
SameOperandsAndResultType,
DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>
]> {
let summary =
"assumption that gives alignment information to the input memref";
Expand Down Expand Up @@ -456,6 +458,7 @@ def MemRef_AllocaScopeReturnOp : MemRef_Op<"alloca_scope.return",
def MemRef_CastOp : MemRef_Op<"cast", [
DeclareOpInterfaceMethods<CastOpInterface>,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
MemRefsNormalizable,
Pure,
SameOperandsAndResultShape,
Expand Down Expand Up @@ -1194,6 +1197,7 @@ def LoadOp : MemRef_Op<"load",
"memref", "result",
"::llvm::cast<MemRefType>($_self).getElementType()">,
MemRefsNormalizable,
DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
DeclareOpInterfaceMethods<PromotableMemOpInterface>,
DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>]> {
let summary = "load operation";
Expand Down Expand Up @@ -1284,6 +1288,7 @@ def LoadOp : MemRef_Op<"load",
def MemRef_MemorySpaceCastOp : MemRef_Op<"memory_space_cast", [
DeclareOpInterfaceMethods<CastOpInterface>,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
DeclareOpInterfaceMethods<MemorySpaceCastOpInterface>,
MemRefsNormalizable,
Pure,
SameOperandsAndResultElementType,
Expand Down Expand Up @@ -1376,6 +1381,7 @@ def MemRef_PrefetchOp : MemRef_Op<"prefetch"> {
def MemRef_ReinterpretCastOp
: MemRef_OpWithOffsetSizesAndStrides<"reinterpret_cast", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
AttrSizedOperandSegments,
MemRefsNormalizable,
Pure,
Expand Down Expand Up @@ -1603,6 +1609,7 @@ def MemRef_RankOp : MemRef_Op<"rank", [Pure]> {

def MemRef_ReshapeOp: MemRef_Op<"reshape", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
Pure,
ViewLikeOpInterface]> {
let summary = "memref reshape operation";
Expand Down Expand Up @@ -1701,6 +1708,7 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :

def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
let summary = "operation to produce a memref with a higher rank.";
let description = [{
Expand Down Expand Up @@ -1822,7 +1830,9 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
}

def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>
]> {
let summary = "operation to produce a memref with a smaller rank.";
let description = [{
The `memref.collapse_shape` op produces a new view with a smaller rank
Expand Down Expand Up @@ -1929,6 +1939,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
"memref", "value",
"::llvm::cast<MemRefType>($_self).getElementType()">,
MemRefsNormalizable,
DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
DeclareOpInterfaceMethods<PromotableMemOpInterface>,
DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>]> {
let summary = "store operation";
Expand Down Expand Up @@ -2006,6 +2017,7 @@ def MemRef_StoreOp : MemRef_Op<"store",

def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
DeclareOpInterfaceMethods<ViewLikeOpInterface>,
AttrSizedOperandSegments,
OffsetSizeAndStrideOpInterface,
Expand Down Expand Up @@ -2281,6 +2293,7 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [

def MemRef_TransposeOp : MemRef_Op<"transpose", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
Pure]>,
Arguments<(ins AnyStridedMemRef:$in, AffineMapAttr:$permutation)>,
Results<(outs AnyStridedMemRef)> {
Expand Down Expand Up @@ -2316,6 +2329,7 @@ def MemRef_TransposeOp : MemRef_Op<"transpose", [

def MemRef_ViewOp : MemRef_Op<"view", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
DeclareOpInterfaceMethods<ViewLikeOpInterface>,
Pure]> {
let summary = "memref view operation";
Expand Down Expand Up @@ -2392,6 +2406,7 @@ def MemRef_ViewOp : MemRef_Op<"view", [
//===----------------------------------------------------------------------===//

def AtomicRMWOp : MemRef_Op<"atomic_rmw", [
DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
AllTypesMatch<["value", "result"]>,
TypesMatchWith<"value type matches element type of memref",
"memref", "value",
Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/IndexingMapOpInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/MemOpInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
Expand Down
16 changes: 11 additions & 5 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/IndexingMapOpInterface.td"
include "mlir/Interfaces/InferIntRangeInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/MemOpInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/VectorInterfaces.td"
include "mlir/Interfaces/ViewLikeInterface.td"
Expand Down Expand Up @@ -1246,6 +1247,7 @@ def Vector_TransferReadOp :
DeclareOpInterfaceMethods<MaskableOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<ConditionallySpeculatable>,
DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
AttrSizedOperandSegments,
DestinationStyleOpInterface
]>,
Expand Down Expand Up @@ -1493,6 +1495,7 @@ def Vector_TransferWriteOp :
DeclareOpInterfaceMethods<MaskableOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<ConditionallySpeculatable>,
DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
AttrSizedOperandSegments,
DestinationStyleOpInterface
]>,
Expand Down Expand Up @@ -1649,6 +1652,7 @@ def Vector_TransferWriteOp :

def Vector_LoadOp : Vector_Op<"load", [
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>
]> {
let summary = "reads an n-D slice of memory into an n-D vector";
let description = [{
Expand Down Expand Up @@ -1765,6 +1769,7 @@ def Vector_LoadOp : Vector_Op<"load", [

def Vector_StoreOp : Vector_Op<"store", [
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>
]> {
let summary = "writes an n-D vector to an n-D slice of memory";
let description = [{
Expand Down Expand Up @@ -1869,7 +1874,7 @@ def Vector_StoreOp : Vector_Op<"store", [
}

def Vector_MaskedLoadOp :
Vector_Op<"maskedload">,
Vector_Op<"maskedload", [DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>]>,
Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
Variadic<Index>:$indices,
VectorOfNonZeroRankOf<[I1]>:$mask,
Expand Down Expand Up @@ -1961,7 +1966,7 @@ def Vector_MaskedLoadOp :
}

def Vector_MaskedStoreOp :
Vector_Op<"maskedstore">,
Vector_Op<"maskedstore", [DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>]>,
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
Variadic<Index>:$indices,
VectorOfNonZeroRankOf<[I1]>:$mask,
Expand Down Expand Up @@ -2041,6 +2046,7 @@ def Vector_MaskedStoreOp :
def Vector_GatherOp :
Vector_Op<"gather", [
DeclareOpInterfaceMethods<MaskableOpInterface>,
DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
]>,
Arguments<(ins Arg<TensorOrMemRef<[AnyType]>, "", [MemRead]>:$base,
Expand Down Expand Up @@ -2144,7 +2150,7 @@ def Vector_GatherOp :
}

def Vector_ScatterOp :
Vector_Op<"scatter">,
Vector_Op<"scatter", [DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>]>,
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
Variadic<Index>:$offsets,
VectorOfNonZeroRankOf<[AnyInteger, Index]>:$indices,
Expand Down Expand Up @@ -2229,7 +2235,7 @@ def Vector_ScatterOp :
}

def Vector_ExpandLoadOp :
Vector_Op<"expandload">,
Vector_Op<"expandload", [DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>]>,
Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
Variadic<Index>:$indices,
FixedVectorOfNonZeroRankOf<[I1]>:$mask,
Expand Down Expand Up @@ -2317,7 +2323,7 @@ def Vector_ExpandLoadOp :
}

def Vector_CompressStoreOp :
Vector_Op<"compressstore">,
Vector_Op<"compressstore", [DeclareOpInterfaceMethods<FuseMemorySpaceCastConsumerOpInterface>]>,
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
Variadic<Index>:$indices,
FixedVectorOfNonZeroRankOf<[I1]>:$mask,
Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/Interfaces/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ add_mlir_interface(IndexingMapOpInterface)
add_mlir_interface(InferIntRangeInterface)
add_mlir_interface(InferTypeOpInterface)
add_mlir_interface(LoopLikeInterface)
add_mlir_interface(MemOpInterfaces)
add_mlir_interface(ParallelCombiningOpInterface)
add_mlir_interface(RuntimeVerifiableOpInterface)
add_mlir_interface(ShapedOpInterfaces)
Expand Down
37 changes: 37 additions & 0 deletions mlir/include/mlir/Interfaces/MemOpInterfaces.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
//===- MemOpInterfaces.h - Memory operation interfaces ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains declarations of interfaces for operations that interact
// with memory.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_INTERFACES_MEMOPINTERFACES_H
#define MLIR_INTERFACES_MEMOPINTERFACES_H

#include "mlir/IR/OpDefinition.h"

namespace mlir {
namespace detail {
/// Attempt to verify the given memory space cast operation.
LogicalResult verifyMemorySpaceCastOpInterface(Operation *op);

/// Tries to fuse inplace a `MemorySpaceCastOpInterface` operation referenced by
/// `operand`. On success, it returns `results`, and sets `modifiedInPlace` to
/// true. It returns failure if `operand` doesn't reference a
/// `MemorySpaceCastOpInterface` op.
FailureOr<SmallVector<Value>>
fuseInPlaceMemorySpaceCastImpl(OpOperand &operand, ValueRange results,
bool &modifiedInPlace);
} // namespace detail
} // namespace mlir

/// Include the generated interface declarations.
#include "mlir/Interfaces/MemOpInterfaces.h.inc"

#endif // MLIR_INTERFACES_MEMOPINTERFACES_H
114 changes: 114 additions & 0 deletions mlir/include/mlir/Interfaces/MemOpInterfaces.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
//===- MemOpInterfaces.td - Memory operation interfaces -----*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains interfaces for operations that interact with memory.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_INTERFACES_MEMOPINTERFACES_TD
#define MLIR_INTERFACES_MEMOPINTERFACES_TD

include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"

def FuseMemorySpaceCastConsumerOpInterface :
OpInterface<"FuseMemorySpaceCastConsumerOpInterface"> {
let description = [{
An interface to fuse memory-space cast operands into a consumer operation.
It is the responsibility of the interface to determine which casts can be
fused into the operation.
}];
let cppNamespace = "::mlir";
let methods = [
InterfaceMethod<[{
Attempt to fuse the incoming cast-like operands. Returns `success`
and any new results on fusion success, otherwise it returns failure.
If new results are produced, these must be compatible with the original
operation results.

The `modifiedInPlace` parameter indicates whether the operation was
modified in place. If `false` and the fusion succeeded, then the
interface guarantees it is valid to erase the original operation.
If `true`, then the interface must guarantee no operations were created
by the method, and that no further IR modification is necessary. It is
considered an error if `modifiedInPlace` is true and the fusion failed.

Any implementations of this method must not erase/replace the original
operation, instead it is the caller responsibility to erase or replace
the op with the results provided by the method.

Finally, any implementations of this method have to guarantee that the
IR remains valid at all times.
}],
"::llvm::FailureOr<::llvm::SmallVector<::mlir::Value>>", "fuseCastOperands",
(ins "::mlir::OpBuilder &":$builder, "bool &":$modifiedInPlace)
>,
];
}

def MemorySpaceCastOpInterface : OpInterface<"MemorySpaceCastOpInterface"> {
let description = [{
An interface for operations that perform memory-space casts. This
interface assumes that the cast operation is `pure`.

These operations expect to have a well-defined ptr-like operand, and
a well-defined target ptr-like result.
}];
let cppNamespace = "::mlir";
let methods = [
InterfaceMethod<[{
Returns the source ptr-like value.
}],
"::mlir::TypedValue<::mlir::PtrLikeTypeInterface>", "getSourcePtr"
>,
InterfaceMethod<[{
Returns the target ptr-like value.
}],
"::mlir::TypedValue<::mlir::PtrLikeTypeInterface>", "getTargetPtr"
>,
InterfaceMethod<[{
Returns whether the memory space cast specified by `tgt` and `src`
is supported.
}],
"bool", "isValidMemorySpaceCast",
(ins "::mlir::PtrLikeTypeInterface":$tgt,
"::mlir::PtrLikeTypeInterface":$src)
>,
InterfaceMethod<[{
Clones the memory space cast op with the given source and target type.
}],
"::mlir::MemorySpaceCastOpInterface", "cloneMemorySpaceCastOp",
(ins "::mlir::OpBuilder &":$builder, "::mlir::Type":$tgt,
"::mlir::Value":$src)
>,
InterfaceMethod<[{
Returns whether the cast allows to be fused.
}],
"bool", "isFusableMemorySpaceCast"
>
];
let verify = [{
return ::mlir::detail::verifyMemorySpaceCastOpInterface($_op);
}];
let dependentTraits = [Pure];
let extraClassDeclaration = [{
/// Returns the underlying `MemorySpaceCastOpInterface` op if `value`
/// is produced by a `MemorySpaceCastOpInterface` op, and
/// `isFusableMemorySpaceCast` returns true, otherwise it returns null.
static ::mlir::MemorySpaceCastOpInterface
getIfFusableCast(::mlir::Value value) {
auto op = ::llvm::dyn_cast_or_null<::mlir::MemorySpaceCastOpInterface>(
value.getDefiningOp());
if (!op || !op.isFusableMemorySpaceCast())
return nullptr;
return op;
}
}];
}

#endif // MLIR_INTERFACES_MEMOPINTERFACES_TD
Loading