Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/MemOpInterfaces.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "mlir/Interfaces/ShapedOpInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
Expand Down
19 changes: 17 additions & 2 deletions mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/InferIntRangeInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/MemOpInterfaces.td"
include "mlir/Interfaces/MemorySlotInterfaces.td"
include "mlir/Interfaces/ShapedOpInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
Expand Down Expand Up @@ -145,7 +146,8 @@ def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
Pure,
ViewLikeOpInterface,
SameOperandsAndResultType
SameOperandsAndResultType,
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>
]> {
let summary =
"assumption that gives alignment information to the input memref";
Expand Down Expand Up @@ -456,6 +458,7 @@ def MemRef_AllocaScopeReturnOp : MemRef_Op<"alloca_scope.return",
def MemRef_CastOp : MemRef_Op<"cast", [
DeclareOpInterfaceMethods<CastOpInterface>,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
MemRefsNormalizable,
Pure,
SameOperandsAndResultShape,
Expand Down Expand Up @@ -1194,6 +1197,7 @@ def LoadOp : MemRef_Op<"load",
"memref", "result",
"::llvm::cast<MemRefType>($_self).getElementType()">,
MemRefsNormalizable,
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
DeclareOpInterfaceMethods<PromotableMemOpInterface>,
DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>]> {
let summary = "load operation";
Expand Down Expand Up @@ -1284,6 +1288,7 @@ def LoadOp : MemRef_Op<"load",
def MemRef_MemorySpaceCastOp : MemRef_Op<"memory_space_cast", [
DeclareOpInterfaceMethods<CastOpInterface>,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
DeclareOpInterfaceMethods<MemorySpaceCastOpInterface>,
MemRefsNormalizable,
Pure,
SameOperandsAndResultElementType,
Expand Down Expand Up @@ -1376,6 +1381,7 @@ def MemRef_PrefetchOp : MemRef_Op<"prefetch"> {
def MemRef_ReinterpretCastOp
: MemRef_OpWithOffsetSizesAndStrides<"reinterpret_cast", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
AttrSizedOperandSegments,
MemRefsNormalizable,
Pure,
Expand Down Expand Up @@ -1603,6 +1609,7 @@ def MemRef_RankOp : MemRef_Op<"rank", [Pure]> {

def MemRef_ReshapeOp: MemRef_Op<"reshape", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
Pure,
ViewLikeOpInterface]> {
let summary = "memref reshape operation";
Expand Down Expand Up @@ -1701,6 +1708,7 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :

def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
let summary = "operation to produce a memref with a higher rank.";
let description = [{
Expand Down Expand Up @@ -1822,7 +1830,9 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
}

def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>
]> {
let summary = "operation to produce a memref with a smaller rank.";
let description = [{
The `memref.collapse_shape` op produces a new view with a smaller rank
Expand Down Expand Up @@ -1929,6 +1939,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
"memref", "value",
"::llvm::cast<MemRefType>($_self).getElementType()">,
MemRefsNormalizable,
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
DeclareOpInterfaceMethods<PromotableMemOpInterface>,
DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>]> {
let summary = "store operation";
Expand Down Expand Up @@ -2006,6 +2017,7 @@ def MemRef_StoreOp : MemRef_Op<"store",

def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
DeclareOpInterfaceMethods<ViewLikeOpInterface>,
AttrSizedOperandSegments,
OffsetSizeAndStrideOpInterface,
Expand Down Expand Up @@ -2281,6 +2293,7 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [

def MemRef_TransposeOp : MemRef_Op<"transpose", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
Pure]>,
Arguments<(ins AnyStridedMemRef:$in, AffineMapAttr:$permutation)>,
Results<(outs AnyStridedMemRef)> {
Expand Down Expand Up @@ -2316,6 +2329,7 @@ def MemRef_TransposeOp : MemRef_Op<"transpose", [

def MemRef_ViewOp : MemRef_Op<"view", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
DeclareOpInterfaceMethods<ViewLikeOpInterface>,
Pure]> {
let summary = "memref view operation";
Expand Down Expand Up @@ -2392,6 +2406,7 @@ def MemRef_ViewOp : MemRef_Op<"view", [
//===----------------------------------------------------------------------===//

def AtomicRMWOp : MemRef_Op<"atomic_rmw", [
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
AllTypesMatch<["value", "result"]>,
TypesMatchWith<"value type matches element type of memref",
"memref", "value",
Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/IndexingMapOpInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/MemOpInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
Expand Down
16 changes: 11 additions & 5 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/IndexingMapOpInterface.td"
include "mlir/Interfaces/InferIntRangeInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/MemOpInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/VectorInterfaces.td"
include "mlir/Interfaces/ViewLikeInterface.td"
Expand Down Expand Up @@ -1246,6 +1247,7 @@ def Vector_TransferReadOp :
DeclareOpInterfaceMethods<MaskableOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<ConditionallySpeculatable>,
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
AttrSizedOperandSegments,
DestinationStyleOpInterface
]>,
Expand Down Expand Up @@ -1493,6 +1495,7 @@ def Vector_TransferWriteOp :
DeclareOpInterfaceMethods<MaskableOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<ConditionallySpeculatable>,
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
AttrSizedOperandSegments,
DestinationStyleOpInterface
]>,
Expand Down Expand Up @@ -1649,6 +1652,7 @@ def Vector_TransferWriteOp :

def Vector_LoadOp : Vector_Op<"load", [
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>
]> {
let summary = "reads an n-D slice of memory into an n-D vector";
let description = [{
Expand Down Expand Up @@ -1765,6 +1769,7 @@ def Vector_LoadOp : Vector_Op<"load", [

def Vector_StoreOp : Vector_Op<"store", [
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>
]> {
let summary = "writes an n-D vector to an n-D slice of memory";
let description = [{
Expand Down Expand Up @@ -1869,7 +1874,7 @@ def Vector_StoreOp : Vector_Op<"store", [
}

def Vector_MaskedLoadOp :
Vector_Op<"maskedload">,
Vector_Op<"maskedload", [DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>]>,
Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
Variadic<Index>:$indices,
VectorOfNonZeroRankOf<[I1]>:$mask,
Expand Down Expand Up @@ -1961,7 +1966,7 @@ def Vector_MaskedLoadOp :
}

def Vector_MaskedStoreOp :
Vector_Op<"maskedstore">,
Vector_Op<"maskedstore", [DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>]>,
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
Variadic<Index>:$indices,
VectorOfNonZeroRankOf<[I1]>:$mask,
Expand Down Expand Up @@ -2041,6 +2046,7 @@ def Vector_MaskedStoreOp :
def Vector_GatherOp :
Vector_Op<"gather", [
DeclareOpInterfaceMethods<MaskableOpInterface>,
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
]>,
Arguments<(ins Arg<TensorOrMemRef<[AnyType]>, "", [MemRead]>:$base,
Expand Down Expand Up @@ -2144,7 +2150,7 @@ def Vector_GatherOp :
}

def Vector_ScatterOp :
Vector_Op<"scatter">,
Vector_Op<"scatter", [DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>]>,
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
Variadic<Index>:$offsets,
VectorOfNonZeroRankOf<[AnyInteger, Index]>:$indices,
Expand Down Expand Up @@ -2229,7 +2235,7 @@ def Vector_ScatterOp :
}

def Vector_ExpandLoadOp :
Vector_Op<"expandload">,
Vector_Op<"expandload", [DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>]>,
Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
Variadic<Index>:$indices,
FixedVectorOfNonZeroRankOf<[I1]>:$mask,
Expand Down Expand Up @@ -2317,7 +2323,7 @@ def Vector_ExpandLoadOp :
}

def Vector_CompressStoreOp :
Vector_Op<"compressstore">,
Vector_Op<"compressstore", [DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>]>,
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
Variadic<Index>:$indices,
FixedVectorOfNonZeroRankOf<[I1]>:$mask,
Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/Interfaces/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ add_mlir_interface(IndexingMapOpInterface)
add_mlir_interface(InferIntRangeInterface)
add_mlir_interface(InferTypeOpInterface)
add_mlir_interface(LoopLikeInterface)
add_mlir_interface(MemOpInterfaces)
add_mlir_interface(ParallelCombiningOpInterface)
add_mlir_interface(RuntimeVerifiableOpInterface)
add_mlir_interface(ShapedOpInterfaces)
Expand Down
36 changes: 36 additions & 0 deletions mlir/include/mlir/Interfaces/MemOpInterfaces.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
//===- MemOpInterfaces.h - Memory operation interfaces ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains declarations of interfaces for operations that interact
// with memory.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_INTERFACES_MEMOPINTERFACES_H
#define MLIR_INTERFACES_MEMOPINTERFACES_H

#include "mlir/IR/OpDefinition.h"

namespace mlir {
namespace detail {
/// Attempt to verify the given memory space cast operation.
LogicalResult verifyMemorySpaceCastOpInterface(Operation *op);

/// Tries to bubble-down inplace a `MemorySpaceCastOpInterface` operation
/// referenced by `operand`. On success, it returns `std::nullopt`. It
/// returns failure if `operand` doesn't reference a
/// `MemorySpaceCastOpInterface` op.
FailureOr<std::optional<SmallVector<Value>>>
bubbleDownInPlaceMemorySpaceCastImpl(OpOperand &operand, ValueRange results);
} // namespace detail
} // namespace mlir

/// Include the generated interface declarations.
#include "mlir/Interfaces/MemOpInterfaces.h.inc"

#endif // MLIR_INTERFACES_MEMOPINTERFACES_H
117 changes: 117 additions & 0 deletions mlir/include/mlir/Interfaces/MemOpInterfaces.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
//===- MemOpInterfaces.td - Memory operation interfaces -----*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains interfaces for operations that interact with memory.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_INTERFACES_MEMOPINTERFACES_TD
#define MLIR_INTERFACES_MEMOPINTERFACES_TD

include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"

def MemorySpaceCastConsumerOpInterface :
OpInterface<"MemorySpaceCastConsumerOpInterface"> {
let description = [{
An interface for operations that can consume memory-space cast-like
operations.
}];
let cppNamespace = "::mlir";
let methods = [
InterfaceMethod<[{
Attempt to bubble-down the incoming cast-like operands. On success
returns a `std::optional<SmallVector<Value>>`, otherwise it returns
failure. If the optional is `std::nullopt` then the cast was performed
in place, otherwise the method returns a list of replacement values.
If new results are produced, these must be compatible with the original
operation results.

If the operation was not modified in place, then the interface
guarantees it is valid to erase the original operation.
If the operation was modified in place, then the interface must
guarantee no operations were created by the method, and that no further
IR modification is necessary.

Any implementations of this method must not erase/replace the original
operation, instead it is the caller responsibility to erase or replace
the op with the results provided by the method.

Finally, any implementations of this method have to guarantee that the
IR remains valid at all times.
}],
"::llvm::FailureOr<std::optional<::llvm::SmallVector<::mlir::Value>>>",
"bubbleDownCasts",
(ins "::mlir::OpBuilder &":$builder)
>,
];
}

def MemorySpaceCastOpInterface : OpInterface<"MemorySpaceCastOpInterface"> {
let description = [{
An interface for operations that perform memory-space casts. This
interface assumes that the cast operation is `pure`.

These operations expect to have a well-defined ptr-like operand, and
a well-defined target ptr-like result.
}];
let cppNamespace = "::mlir";
let methods = [
InterfaceMethod<[{
Returns the source ptr-like value.
}],
"::mlir::TypedValue<::mlir::PtrLikeTypeInterface>", "getSourcePtr"
>,
InterfaceMethod<[{
Returns the target ptr-like value.
}],
"::mlir::TypedValue<::mlir::PtrLikeTypeInterface>", "getTargetPtr"
>,
InterfaceMethod<[{
Returns whether the memory space cast specified by `tgt` and `src`
is supported.
}],
"bool", "isValidMemorySpaceCast",
(ins "::mlir::PtrLikeTypeInterface":$tgt,
"::mlir::PtrLikeTypeInterface":$src)
>,
InterfaceMethod<[{
Clones the memory space cast op with the given source and target type.
}],
"::mlir::MemorySpaceCastOpInterface", "cloneMemorySpaceCastOp",
(ins "::mlir::OpBuilder &":$builder, "::mlir::PtrLikeTypeInterface":$tgt,
"::mlir::TypedValue<::mlir::PtrLikeTypeInterface>":$src)
>,
InterfaceMethod<[{
Returns whether the source pointer of the memory-space cast can be used
by the `MemorySpaceCastConsumerOpInterface::bubbleDownCasts` method to
promote the source pointer and bubble down the cast.
}],
"bool", "isSourcePromotable"
>
];
let verify = [{
return ::mlir::detail::verifyMemorySpaceCastOpInterface($_op);
}];
let dependentTraits = [Pure];
let extraClassDeclaration = [{
/// Returns the underlying `MemorySpaceCastOpInterface` op if `value`
/// is produced by a `MemorySpaceCastOpInterface` op, and
/// `isSourcePromotable` returns true, otherwise it returns null.
static ::mlir::MemorySpaceCastOpInterface
getIfPromotableCast(::mlir::Value value) {
auto op = ::llvm::dyn_cast_or_null<::mlir::MemorySpaceCastOpInterface>(
value.getDefiningOp());
if (!op || !op.isSourcePromotable())
return nullptr;
return op;
}
}];
}

#endif // MLIR_INTERFACES_MEMOPINTERFACES_TD
Loading