Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/MemOpInterfaces.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "mlir/Interfaces/ShapedOpInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
Expand Down
40 changes: 38 additions & 2 deletions mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/InferIntRangeInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/MemOpInterfaces.td"
include "mlir/Interfaces/MemorySlotInterfaces.td"
include "mlir/Interfaces/ShapedOpInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
Expand Down Expand Up @@ -145,7 +146,8 @@ def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
Pure,
ViewLikeOpInterface,
SameOperandsAndResultType
SameOperandsAndResultType,
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>
]> {
let summary =
"assumption that gives alignment information to the input memref";
Expand Down Expand Up @@ -456,6 +458,7 @@ def MemRef_AllocaScopeReturnOp : MemRef_Op<"alloca_scope.return",
def MemRef_CastOp : MemRef_Op<"cast", [
DeclareOpInterfaceMethods<CastOpInterface>,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
MemRefsNormalizable,
Pure,
SameOperandsAndResultShape,
Expand Down Expand Up @@ -1194,6 +1197,7 @@ def LoadOp : MemRef_Op<"load",
"memref", "result",
"::llvm::cast<MemRefType>($_self).getElementType()">,
MemRefsNormalizable,
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
DeclareOpInterfaceMethods<PromotableMemOpInterface>,
DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>]> {
let summary = "load operation";
Expand Down Expand Up @@ -1284,6 +1288,7 @@ def LoadOp : MemRef_Op<"load",
def MemRef_MemorySpaceCastOp : MemRef_Op<"memory_space_cast", [
DeclareOpInterfaceMethods<CastOpInterface>,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
MemorySpaceCastOpInterface,
MemRefsNormalizable,
Pure,
SameOperandsAndResultElementType,
Expand Down Expand Up @@ -1321,6 +1326,27 @@ def MemRef_MemorySpaceCastOp : MemRef_Op<"memory_space_cast", [

let extraClassDeclaration = [{
Value getViewSource() { return getSource(); }

//===------------------------------------------------------------------===//
// MemorySpaceCastConsumerOpInterface
//===------------------------------------------------------------------===//
/// Returns the `source` memref.
TypedValue<PtrLikeTypeInterface> getSourcePtr();
/// Returns the `dest` memref.
TypedValue<PtrLikeTypeInterface> getTargetPtr();
/// Returns whether the memory-space cast is valid. Only casts between
/// memrefs are considered valid. Further, the `tgt` and `src` should only
/// differ on the memory-space parameter of the memref type.
bool isValidMemorySpaceCast(PtrLikeTypeInterface tgt,
PtrLikeTypeInterface src);
/// Clones the operation using a new target type and source value.
MemorySpaceCastOpInterface cloneMemorySpaceCastOp(
OpBuilder &b, PtrLikeTypeInterface tgt,
TypedValue<PtrLikeTypeInterface> src);
/// Returns whether the `source` value can be promoted by the
/// `MemorySpaceCastConsumerOpInterface::bubbleDownCasts` method. The only
/// casts the op recognizes as promotable are to the generic memory-space.
bool isSourcePromotable();
}];

let hasFolder = 1;
Expand Down Expand Up @@ -1376,6 +1402,7 @@ def MemRef_PrefetchOp : MemRef_Op<"prefetch"> {
def MemRef_ReinterpretCastOp
: MemRef_OpWithOffsetSizesAndStrides<"reinterpret_cast", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
AttrSizedOperandSegments,
MemRefsNormalizable,
Pure,
Expand Down Expand Up @@ -1603,6 +1630,7 @@ def MemRef_RankOp : MemRef_Op<"rank", [Pure]> {

def MemRef_ReshapeOp: MemRef_Op<"reshape", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
Pure,
ViewLikeOpInterface]> {
let summary = "memref reshape operation";
Expand Down Expand Up @@ -1701,6 +1729,7 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :

def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
let summary = "operation to produce a memref with a higher rank.";
let description = [{
Expand Down Expand Up @@ -1822,7 +1851,9 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
}

def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>
]> {
let summary = "operation to produce a memref with a smaller rank.";
let description = [{
The `memref.collapse_shape` op produces a new view with a smaller rank
Expand Down Expand Up @@ -1929,6 +1960,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
"memref", "value",
"::llvm::cast<MemRefType>($_self).getElementType()">,
MemRefsNormalizable,
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
DeclareOpInterfaceMethods<PromotableMemOpInterface>,
DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>]> {
let summary = "store operation";
Expand Down Expand Up @@ -2006,6 +2038,7 @@ def MemRef_StoreOp : MemRef_Op<"store",

def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
DeclareOpInterfaceMethods<ViewLikeOpInterface>,
AttrSizedOperandSegments,
OffsetSizeAndStrideOpInterface,
Expand Down Expand Up @@ -2281,6 +2314,7 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [

def MemRef_TransposeOp : MemRef_Op<"transpose", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
Pure]>,
Arguments<(ins AnyStridedMemRef:$in, AffineMapAttr:$permutation)>,
Results<(outs AnyStridedMemRef)> {
Expand Down Expand Up @@ -2316,6 +2350,7 @@ def MemRef_TransposeOp : MemRef_Op<"transpose", [

def MemRef_ViewOp : MemRef_Op<"view", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
DeclareOpInterfaceMethods<ViewLikeOpInterface>,
Pure]> {
let summary = "memref view operation";
Expand Down Expand Up @@ -2392,6 +2427,7 @@ def MemRef_ViewOp : MemRef_Op<"view", [
//===----------------------------------------------------------------------===//

def AtomicRMWOp : MemRef_Op<"atomic_rmw", [
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
AllTypesMatch<["value", "result"]>,
TypesMatchWith<"value type matches element type of memref",
"memref", "value",
Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/IndexingMapOpInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/MemOpInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
Expand Down
16 changes: 11 additions & 5 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/IndexingMapOpInterface.td"
include "mlir/Interfaces/InferIntRangeInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/MemOpInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/VectorInterfaces.td"
include "mlir/Interfaces/ViewLikeInterface.td"
Expand Down Expand Up @@ -1246,6 +1247,7 @@ def Vector_TransferReadOp :
DeclareOpInterfaceMethods<MaskableOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<ConditionallySpeculatable>,
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
AttrSizedOperandSegments,
DestinationStyleOpInterface
]>,
Expand Down Expand Up @@ -1493,6 +1495,7 @@ def Vector_TransferWriteOp :
DeclareOpInterfaceMethods<MaskableOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<ConditionallySpeculatable>,
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
AttrSizedOperandSegments,
DestinationStyleOpInterface
]>,
Expand Down Expand Up @@ -1649,6 +1652,7 @@ def Vector_TransferWriteOp :

def Vector_LoadOp : Vector_Op<"load", [
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>
]> {
let summary = "reads an n-D slice of memory into an n-D vector";
let description = [{
Expand Down Expand Up @@ -1765,6 +1769,7 @@ def Vector_LoadOp : Vector_Op<"load", [

def Vector_StoreOp : Vector_Op<"store", [
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>
]> {
let summary = "writes an n-D vector to an n-D slice of memory";
let description = [{
Expand Down Expand Up @@ -1869,7 +1874,7 @@ def Vector_StoreOp : Vector_Op<"store", [
}

def Vector_MaskedLoadOp :
Vector_Op<"maskedload">,
Vector_Op<"maskedload", [DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>]>,
Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
Variadic<Index>:$indices,
VectorOfNonZeroRankOf<[I1]>:$mask,
Expand Down Expand Up @@ -1961,7 +1966,7 @@ def Vector_MaskedLoadOp :
}

def Vector_MaskedStoreOp :
Vector_Op<"maskedstore">,
Vector_Op<"maskedstore", [DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>]>,
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
Variadic<Index>:$indices,
VectorOfNonZeroRankOf<[I1]>:$mask,
Expand Down Expand Up @@ -2041,6 +2046,7 @@ def Vector_MaskedStoreOp :
def Vector_GatherOp :
Vector_Op<"gather", [
DeclareOpInterfaceMethods<MaskableOpInterface>,
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
]>,
Arguments<(ins Arg<TensorOrMemRef<[AnyType]>, "", [MemRead]>:$base,
Expand Down Expand Up @@ -2144,7 +2150,7 @@ def Vector_GatherOp :
}

def Vector_ScatterOp :
Vector_Op<"scatter">,
Vector_Op<"scatter", [DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>]>,
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
Variadic<Index>:$offsets,
VectorOfNonZeroRankOf<[AnyInteger, Index]>:$indices,
Expand Down Expand Up @@ -2229,7 +2235,7 @@ def Vector_ScatterOp :
}

def Vector_ExpandLoadOp :
Vector_Op<"expandload">,
Vector_Op<"expandload", [DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>]>,
Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
Variadic<Index>:$indices,
FixedVectorOfNonZeroRankOf<[I1]>:$mask,
Expand Down Expand Up @@ -2317,7 +2323,7 @@ def Vector_ExpandLoadOp :
}

def Vector_CompressStoreOp :
Vector_Op<"compressstore">,
Vector_Op<"compressstore", [DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>]>,
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
Variadic<Index>:$indices,
FixedVectorOfNonZeroRankOf<[I1]>:$mask,
Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/Interfaces/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ add_mlir_interface(IndexingMapOpInterface)
add_mlir_interface(InferIntRangeInterface)
add_mlir_interface(InferTypeOpInterface)
add_mlir_interface(LoopLikeInterface)
add_mlir_interface(MemOpInterfaces)
add_mlir_interface(ParallelCombiningOpInterface)
add_mlir_interface(RuntimeVerifiableOpInterface)
add_mlir_interface(ShapedOpInterfaces)
Expand Down
36 changes: 36 additions & 0 deletions mlir/include/mlir/Interfaces/MemOpInterfaces.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
//===- MemOpInterfaces.h - Memory operation interfaces ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains declarations of interfaces for operations that interact
// with memory.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_INTERFACES_MEMOPINTERFACES_H
#define MLIR_INTERFACES_MEMOPINTERFACES_H

#include "mlir/IR/OpDefinition.h"

namespace mlir {
namespace detail {
/// Attempt to verify the given memory space cast operation.
LogicalResult verifyMemorySpaceCastOpInterface(Operation *op);

/// Tries to bubble-down inplace a `MemorySpaceCastOpInterface` operation
/// referenced by `operand`. On success, it returns `std::nullopt`. It
/// returns failure if `operand` doesn't reference a
/// `MemorySpaceCastOpInterface` op.
FailureOr<std::optional<SmallVector<Value>>>
bubbleDownInPlaceMemorySpaceCastImpl(OpOperand &operand, ValueRange results);
} // namespace detail
} // namespace mlir

/// Include the generated interface declarations.
#include "mlir/Interfaces/MemOpInterfaces.h.inc"

#endif // MLIR_INTERFACES_MEMOPINTERFACES_H
Loading