Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Interfaces/AlignmentAttrInterface.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
Expand Down
12 changes: 9 additions & 3 deletions mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

include "mlir/Dialect/Arith/IR/ArithBase.td"
include "mlir/Dialect/MemRef/IR/MemRefBase.td"
include "mlir/Interfaces/AlignmentAttrInterface.td"
include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/InferIntRangeInterface.td"
Expand Down Expand Up @@ -64,7 +65,8 @@ class AllocLikeOp<string mnemonic,
list<Trait> traits = []> :
MemRef_Op<mnemonic,
!listconcat([
AttrSizedOperandSegments
AttrSizedOperandSegments,
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>
], traits)> {

let arguments = (ins Variadic<Index>:$dynamicSizes,
Expand Down Expand Up @@ -267,7 +269,8 @@ def MemRef_AllocOp : AllocLikeOp<"alloc", DefaultResource, [
//===----------------------------------------------------------------------===//


def MemRef_ReallocOp : MemRef_Op<"realloc"> {
def MemRef_ReallocOp : MemRef_Op<"realloc",
[DeclareOpInterfaceMethods<AlignmentAttrOpInterface>]> {
let summary = "memory reallocation operation";
let description = [{
The `realloc` operation changes the size of a memory region. The memory
Expand Down Expand Up @@ -1157,7 +1160,8 @@ def MemRef_GetGlobalOp : MemRef_Op<"get_global",
// GlobalOp
//===----------------------------------------------------------------------===//

def MemRef_GlobalOp : MemRef_Op<"global", [Symbol]> {
def MemRef_GlobalOp : MemRef_Op<"global", [Symbol,
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>]> {
let summary = "declare or define a global memref variable";
let description = [{
The `memref.global` operation declares or defines a named global memref
Expand Down Expand Up @@ -1232,6 +1236,7 @@ def LoadOp : MemRef_Op<"load",
"memref", "result",
"::llvm::cast<MemRefType>($_self).getElementType()">,
MemRefsNormalizable,
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>,
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
DeclareOpInterfaceMethods<PromotableMemOpInterface>,
DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>]> {
Expand Down Expand Up @@ -1999,6 +2004,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
"memref", "value",
"::llvm::cast<MemRefType>($_self).getElementType()">,
MemRefsNormalizable,
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>,
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
DeclareOpInterfaceMethods<PromotableMemOpInterface>,
DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>]> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
#ifndef MLIR_DIALECT_SPIRV_IR_COOPERATIVE_MATRIX_OPS
#define MLIR_DIALECT_SPIRV_IR_COOPERATIVE_MATRIX_OPS

include "mlir/Interfaces/AlignmentAttrInterface.td"

//===----------------------------------------------------------------------===//
// SPV_KHR_cooperative_matrix extension ops.
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -62,7 +64,7 @@ def SPIRV_KHRCooperativeMatrixLengthOp :

// -----

def SPIRV_KHRCooperativeMatrixLoadOp : SPIRV_KhrVendorOp<"CooperativeMatrixLoad", []> {
def SPIRV_KHRCooperativeMatrixLoadOp : SPIRV_KhrVendorOp<"CooperativeMatrixLoad", [DeclareOpInterfaceMethods<AlignmentAttrOpInterface>]> {
let summary = "Loads a cooperative matrix through a pointer";

let description = [{
Expand Down Expand Up @@ -148,7 +150,7 @@ def SPIRV_KHRCooperativeMatrixLoadOp : SPIRV_KhrVendorOp<"CooperativeMatrixLoad"

// -----

def SPIRV_KHRCooperativeMatrixStoreOp : SPIRV_KhrVendorOp<"CooperativeMatrixStore", []> {
def SPIRV_KHRCooperativeMatrixStoreOp : SPIRV_KhrVendorOp<"CooperativeMatrixStore", [DeclareOpInterfaceMethods<AlignmentAttrOpInterface>]> {
let summary = "Stores a cooperative matrix through a pointer";

let description = [{
Expand Down
8 changes: 5 additions & 3 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
#define MLIR_DIALECT_SPIRV_IR_MEMORY_OPS

include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"
include "mlir/Interfaces/AlignmentAttrInterface.td"


// -----

Expand Down Expand Up @@ -79,7 +81,7 @@ def SPIRV_AccessChainOp : SPIRV_Op<"AccessChain", [Pure]> {

// -----

def SPIRV_CopyMemoryOp : SPIRV_Op<"CopyMemory", []> {
def SPIRV_CopyMemoryOp : SPIRV_Op<"CopyMemory", [DeclareOpInterfaceMethods<AlignmentAttrOpInterface>]> {
let summary = [{
Copy from the memory pointed to by Source to the memory pointed to by
Target. Both operands must be non-void pointers and having the same <id>
Expand Down Expand Up @@ -182,7 +184,7 @@ def SPIRV_InBoundsPtrAccessChainOp : SPIRV_Op<"InBoundsPtrAccessChain", [Pure]>

// -----

def SPIRV_LoadOp : SPIRV_Op<"Load", []> {
def SPIRV_LoadOp : SPIRV_Op<"Load", [DeclareOpInterfaceMethods<AlignmentAttrOpInterface>]> {
let summary = "Load through a pointer.";

let description = [{
Expand Down Expand Up @@ -310,7 +312,7 @@ def SPIRV_PtrAccessChainOp : SPIRV_Op<"PtrAccessChain", [Pure]> {

// -----

def SPIRV_StoreOp : SPIRV_Op<"Store", []> {
def SPIRV_StoreOp : SPIRV_Op<"Store", [DeclareOpInterfaceMethods<AlignmentAttrOpInterface>]> {
let summary = "Store through a pointer.";

let description = [{
Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "mlir/Dialect/SPIRV/Interfaces/SPIRVImageInterfaces.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Interfaces/AlignmentAttrInterface.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/AlignmentAttrInterface.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/IndexingMapOpInterface.h"
Expand Down
35 changes: 27 additions & 8 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.td"
include "mlir/Dialect/Vector/Interfaces/MaskingOpInterface.td"
include "mlir/Dialect/Vector/IR/Vector.td"
include "mlir/Dialect/Vector/IR/VectorAttributes.td"
include "mlir/Interfaces/AlignmentAttrInterface.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/IndexingMapOpInterface.td"
Expand Down Expand Up @@ -1652,7 +1653,8 @@ def Vector_TransferWriteOp :

def Vector_LoadOp : Vector_Op<"load", [
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>
]> {
let summary = "reads an n-D slice of memory into an n-D vector";
let description = [{
Expand Down Expand Up @@ -1769,7 +1771,8 @@ def Vector_LoadOp : Vector_Op<"load", [

def Vector_StoreOp : Vector_Op<"store", [
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>
]> {
let summary = "writes an n-D vector to an n-D slice of memory";
let description = [{
Expand Down Expand Up @@ -1874,7 +1877,10 @@ def Vector_StoreOp : Vector_Op<"store", [
}

def Vector_MaskedLoadOp :
Vector_Op<"maskedload", [DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>]>,
Vector_Op<"maskedload", [
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>
]>,
Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
Variadic<Index>:$indices,
VectorOfNonZeroRankOf<[I1]>:$mask,
Expand Down Expand Up @@ -1966,7 +1972,10 @@ def Vector_MaskedLoadOp :
}

def Vector_MaskedStoreOp :
Vector_Op<"maskedstore", [DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>]>,
Vector_Op<"maskedstore", [
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>
]>,
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
Variadic<Index>:$indices,
VectorOfNonZeroRankOf<[I1]>:$mask,
Expand Down Expand Up @@ -2047,7 +2056,8 @@ def Vector_GatherOp :
Vector_Op<"gather", [
DeclareOpInterfaceMethods<MaskableOpInterface>,
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>
]>,
Arguments<(ins Arg<TensorOrMemRef<[AnyType]>, "", [MemRead]>:$base,
Variadic<Index>:$offsets,
Expand Down Expand Up @@ -2150,7 +2160,10 @@ def Vector_GatherOp :
}

def Vector_ScatterOp :
Vector_Op<"scatter", [DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>]>,
Vector_Op<"scatter", [
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>
]>,
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
Variadic<Index>:$offsets,
VectorOfNonZeroRankOf<[AnyInteger, Index]>:$indices,
Expand Down Expand Up @@ -2235,7 +2248,10 @@ def Vector_ScatterOp :
}

def Vector_ExpandLoadOp :
Vector_Op<"expandload", [DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>]>,
Vector_Op<"expandload", [
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>
]>,
Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
Variadic<Index>:$indices,
FixedVectorOfNonZeroRankOf<[I1]>:$mask,
Expand Down Expand Up @@ -2323,7 +2339,10 @@ def Vector_ExpandLoadOp :
}

def Vector_CompressStoreOp :
Vector_Op<"compressstore", [DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>]>,
Vector_Op<"compressstore", [
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>
]>,
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
Variadic<Index>:$indices,
FixedVectorOfNonZeroRankOf<[I1]>:$mask,
Expand Down
21 changes: 21 additions & 0 deletions mlir/include/mlir/Interfaces/AlignmentAttrInterface.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
//===- AlignmentAttrInterface.h - Alignment attribute interface -*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_INTERFACES_ALIGNMENTATTRINTERFACE_H
#define MLIR_INTERFACES_ALIGNMENTATTRINTERFACE_H

#include "mlir/IR/OpDefinition.h"
#include "llvm/Support/Alignment.h"

namespace mlir {
class MLIRContext;
} // namespace mlir

#include "mlir/Interfaces/AlignmentAttrInterface.h.inc"

#endif // MLIR_INTERFACES_ALIGNMENTATTRINTERFACE_H
55 changes: 55 additions & 0 deletions mlir/include/mlir/Interfaces/AlignmentAttrInterface.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
//===- AlignmentAttrInterface.td - Alignment attribute interface -*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines an interface for operations that expose an optional
// alignment attribute.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_INTERFACES_ALIGNMENTATTRINTERFACE_TD
#define MLIR_INTERFACES_ALIGNMENTATTRINTERFACE_TD

include "mlir/IR/OpBase.td"

def AlignmentAttrOpInterface : OpInterface<"AlignmentAttrOpInterface"> {
let description = [{
An interface for operations that carry an optional alignment attribute and
want to expose it as an `llvm::MaybeAlign` helper.
}];

let cppNamespace = "::mlir";

let methods = [
InterfaceMethod<[{
Returns the alignment encoded on the operation as an `llvm::MaybeAlign`.
Operations providing a differently named accessor can override the
default implementation.
}],
"::llvm::MaybeAlign",
"getMaybeAlign",
(ins),
[{
auto alignmentOpt = $_op.getAlignment();
if (!alignmentOpt)
return ::llvm::MaybeAlign();
return ::llvm::MaybeAlign(static_cast<uint64_t>(*alignmentOpt));
}]
>
];

let extraTraitClassDeclaration = [{
::llvm::MaybeAlign getMaybeAlign() {
auto alignmentOpt = (*static_cast<ConcreteOp *>(this)).getAlignment();
if (!alignmentOpt)
return ::llvm::MaybeAlign();
return ::llvm::MaybeAlign(static_cast<uint64_t>(*alignmentOpt));
}
}];
}

#endif // MLIR_INTERFACES_ALIGNMENTATTRINTERFACE_TD
1 change: 1 addition & 0 deletions mlir/include/mlir/Interfaces/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
add_mlir_interface(AlignmentAttrInterface)
add_mlir_interface(CallInterfaces)
add_mlir_interface(CastInterfaces)
add_mlir_interface(ControlFlowInterfaces)
Expand Down
17 changes: 12 additions & 5 deletions mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/Support/Alignment.h"
#include "llvm/Support/Casting.h"

#include <optional>
Expand Down Expand Up @@ -248,7 +249,9 @@ class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {

// Resolve alignment.
// Explicit alignment takes priority over use-vector-alignment.
unsigned align = loadOrStoreOp.getAlignment().value_or(0);
unsigned align = 0;
if (llvm::MaybeAlign maybeAlign = loadOrStoreOp.getMaybeAlign())
align = maybeAlign->value();
if (!align &&
failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vectorTy,
memRefTy, align, useVectorAlignment)))
Expand Down Expand Up @@ -301,7 +304,9 @@ class VectorGatherOpConversion

// Resolve alignment.
// Explicit alignment takes priority over use-vector-alignment.
unsigned align = gather.getAlignment().value_or(0);
unsigned align = 0;
if (llvm::MaybeAlign maybeAlign = gather.getMaybeAlign())
align = maybeAlign->value();
if (!align &&
failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
memRefType, align, useVectorAlignment)))
Expand Down Expand Up @@ -358,7 +363,9 @@ class VectorScatterOpConversion

// Resolve alignment.
// Explicit alignment takes priority over use-vector-alignment.
unsigned align = scatter.getAlignment().value_or(0);
unsigned align = 0;
if (llvm::MaybeAlign maybeAlign = scatter.getMaybeAlign())
align = maybeAlign->value();
if (!align &&
failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
memRefType, align, useVectorAlignment)))
Expand Down Expand Up @@ -407,7 +414,7 @@ class VectorExpandLoadOpConversion
// From:
// https://llvm.org/docs/LangRef.html#llvm-masked-expandload-intrinsics
// The pointer alignment defaults to 1.
uint64_t alignment = expand.getAlignment().value_or(1);
uint64_t alignment = expand.getMaybeAlign().valueOrOne().value();

rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru(),
Expand Down Expand Up @@ -435,7 +442,7 @@ class VectorCompressStoreOpConversion
// From:
// https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics
// The pointer alignment defaults to 1.
uint64_t alignment = compress.getAlignment().value_or(1);
uint64_t alignment = compress.getMaybeAlign().valueOrOne().value();

rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
compress, adaptor.getValueToStore(), ptr, adaptor.getMask(), alignment);
Expand Down
Loading