Skip to content

Commit 3f081df

Browse files
committed
[mlir] Introduce AlignmentAttrOpInterface to expose MaybeAlign
1 parent fe9fba8 commit 3f081df

File tree

11 files changed

+145
-20
lines changed

11 files changed

+145
-20
lines changed

mlir/include/mlir/Dialect/MemRef/IR/MemRef.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "mlir/Dialect/Arith/IR/Arith.h"
1414
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
1515
#include "mlir/IR/Dialect.h"
16+
#include "mlir/Interfaces/AlignmentAttrInterface.h"
1617
#include "mlir/Interfaces/CallInterfaces.h"
1718
#include "mlir/Interfaces/CastInterfaces.h"
1819
#include "mlir/Interfaces/ControlFlowInterfaces.h"

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
include "mlir/Dialect/Arith/IR/ArithBase.td"
1313
include "mlir/Dialect/MemRef/IR/MemRefBase.td"
14+
include "mlir/Interfaces/AlignmentAttrInterface.td"
1415
include "mlir/Interfaces/CastInterfaces.td"
1516
include "mlir/Interfaces/ControlFlowInterfaces.td"
1617
include "mlir/Interfaces/InferIntRangeInterface.td"
@@ -64,7 +65,8 @@ class AllocLikeOp<string mnemonic,
6465
list<Trait> traits = []> :
6566
MemRef_Op<mnemonic,
6667
!listconcat([
67-
AttrSizedOperandSegments
68+
AttrSizedOperandSegments,
69+
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>
6870
], traits)> {
6971

7072
let arguments = (ins Variadic<Index>:$dynamicSizes,
@@ -232,7 +234,8 @@ def MemRef_AllocOp : AllocLikeOp<"alloc", DefaultResource, [
232234
//===----------------------------------------------------------------------===//
233235

234236

235-
def MemRef_ReallocOp : MemRef_Op<"realloc"> {
237+
def MemRef_ReallocOp : MemRef_Op<"realloc",
238+
[DeclareOpInterfaceMethods<AlignmentAttrOpInterface>]> {
236239
let summary = "memory reallocation operation";
237240
let description = [{
238241
The `realloc` operation changes the size of a memory region. The memory
@@ -1122,7 +1125,8 @@ def MemRef_GetGlobalOp : MemRef_Op<"get_global",
11221125
// GlobalOp
11231126
//===----------------------------------------------------------------------===//
11241127

1125-
def MemRef_GlobalOp : MemRef_Op<"global", [Symbol]> {
1128+
def MemRef_GlobalOp : MemRef_Op<"global", [Symbol,
1129+
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>]> {
11261130
let summary = "declare or define a global memref variable";
11271131
let description = [{
11281132
The `memref.global` operation declares or defines a named global memref
@@ -1197,6 +1201,7 @@ def LoadOp : MemRef_Op<"load",
11971201
"memref", "result",
11981202
"::llvm::cast<MemRefType>($_self).getElementType()">,
11991203
MemRefsNormalizable,
1204+
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>,
12001205
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
12011206
DeclareOpInterfaceMethods<PromotableMemOpInterface>,
12021207
DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>]> {
@@ -1964,6 +1969,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
19641969
"memref", "value",
19651970
"::llvm::cast<MemRefType>($_self).getElementType()">,
19661971
MemRefsNormalizable,
1972+
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>,
19671973
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
19681974
DeclareOpInterfaceMethods<PromotableMemOpInterface>,
19691975
DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>]> {

mlir/include/mlir/Dialect/Vector/IR/VectorOps.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "mlir/IR/Dialect.h"
2424
#include "mlir/IR/OpDefinition.h"
2525
#include "mlir/IR/PatternMatch.h"
26+
#include "mlir/Interfaces/AlignmentAttrInterface.h"
2627
#include "mlir/Interfaces/ControlFlowInterfaces.h"
2728
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
2829
#include "mlir/Interfaces/IndexingMapOpInterface.h"

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.td"
1919
include "mlir/Dialect/Vector/Interfaces/MaskingOpInterface.td"
2020
include "mlir/Dialect/Vector/IR/Vector.td"
2121
include "mlir/Dialect/Vector/IR/VectorAttributes.td"
22+
include "mlir/Interfaces/AlignmentAttrInterface.td"
2223
include "mlir/Interfaces/ControlFlowInterfaces.td"
2324
include "mlir/Interfaces/DestinationStyleOpInterface.td"
2425
include "mlir/Interfaces/IndexingMapOpInterface.td"
@@ -1652,7 +1653,8 @@ def Vector_TransferWriteOp :
16521653

16531654
def Vector_LoadOp : Vector_Op<"load", [
16541655
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
1655-
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>
1656+
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
1657+
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>
16561658
]> {
16571659
let summary = "reads an n-D slice of memory into an n-D vector";
16581660
let description = [{
@@ -1769,7 +1771,8 @@ def Vector_LoadOp : Vector_Op<"load", [
17691771

17701772
def Vector_StoreOp : Vector_Op<"store", [
17711773
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
1772-
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>
1774+
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
1775+
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>
17731776
]> {
17741777
let summary = "writes an n-D vector to an n-D slice of memory";
17751778
let description = [{
@@ -1874,7 +1877,10 @@ def Vector_StoreOp : Vector_Op<"store", [
18741877
}
18751878

18761879
def Vector_MaskedLoadOp :
1877-
Vector_Op<"maskedload", [DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>]>,
1880+
Vector_Op<"maskedload", [
1881+
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
1882+
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>
1883+
]>,
18781884
Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
18791885
Variadic<Index>:$indices,
18801886
VectorOfNonZeroRankOf<[I1]>:$mask,
@@ -1966,7 +1972,10 @@ def Vector_MaskedLoadOp :
19661972
}
19671973

19681974
def Vector_MaskedStoreOp :
1969-
Vector_Op<"maskedstore", [DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>]>,
1975+
Vector_Op<"maskedstore", [
1976+
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
1977+
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>
1978+
]>,
19701979
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
19711980
Variadic<Index>:$indices,
19721981
VectorOfNonZeroRankOf<[I1]>:$mask,
@@ -2047,7 +2056,8 @@ def Vector_GatherOp :
20472056
Vector_Op<"gather", [
20482057
DeclareOpInterfaceMethods<MaskableOpInterface>,
20492058
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
2050-
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
2059+
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
2060+
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>
20512061
]>,
20522062
Arguments<(ins Arg<TensorOrMemRef<[AnyType]>, "", [MemRead]>:$base,
20532063
Variadic<Index>:$offsets,
@@ -2150,7 +2160,10 @@ def Vector_GatherOp :
21502160
}
21512161

21522162
def Vector_ScatterOp :
2153-
Vector_Op<"scatter", [DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>]>,
2163+
Vector_Op<"scatter", [
2164+
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
2165+
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>
2166+
]>,
21542167
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
21552168
Variadic<Index>:$offsets,
21562169
VectorOfNonZeroRankOf<[AnyInteger, Index]>:$indices,
@@ -2235,7 +2248,10 @@ def Vector_ScatterOp :
22352248
}
22362249

22372250
def Vector_ExpandLoadOp :
2238-
Vector_Op<"expandload", [DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>]>,
2251+
Vector_Op<"expandload", [
2252+
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
2253+
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>
2254+
]>,
22392255
Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
22402256
Variadic<Index>:$indices,
22412257
FixedVectorOfNonZeroRankOf<[I1]>:$mask,
@@ -2323,7 +2339,10 @@ def Vector_ExpandLoadOp :
23232339
}
23242340

23252341
def Vector_CompressStoreOp :
2326-
Vector_Op<"compressstore", [DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>]>,
2342+
Vector_Op<"compressstore", [
2343+
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
2344+
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>
2345+
]>,
23272346
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
23282347
Variadic<Index>:$indices,
23292348
FixedVectorOfNonZeroRankOf<[I1]>:$mask,
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
//===- AlignmentAttrInterface.h - Alignment attribute interface -*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_INTERFACES_ALIGNMENTATTRINTERFACE_H
10+
#define MLIR_INTERFACES_ALIGNMENTATTRINTERFACE_H
11+
12+
#include "mlir/IR/OpDefinition.h"
13+
#include "llvm/Support/Alignment.h"
14+
15+
namespace mlir {
16+
class MLIRContext;
17+
} // namespace mlir
18+
19+
#include "mlir/Interfaces/AlignmentAttrInterface.h.inc"
20+
21+
#endif // MLIR_INTERFACES_ALIGNMENTATTRINTERFACE_H
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
//===- AlignmentAttrInterface.td - Alignment attribute interface -*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file defines an interface for operations that expose an optional
10+
// alignment attribute.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#ifndef MLIR_INTERFACES_ALIGNMENTATTRINTERFACE_TD
15+
#define MLIR_INTERFACES_ALIGNMENTATTRINTERFACE_TD
16+
17+
include "mlir/IR/OpBase.td"
18+
19+
def AlignmentAttrOpInterface : OpInterface<"AlignmentAttrOpInterface"> {
20+
let description = [{
21+
An interface for operations that carry an optional alignment attribute and
22+
want to expose it as an `llvm::MaybeAlign` helper.
23+
}];
24+
25+
let cppNamespace = "::mlir";
26+
27+
let methods = [
28+
InterfaceMethod<[{
29+
Returns the alignment encoded on the operation as an `llvm::MaybeAlign`.
30+
Operations providing a differently named accessor can override the
31+
default implementation.
32+
}],
33+
"::llvm::MaybeAlign",
34+
"getMaybeAlign",
35+
(ins),
36+
[{
37+
auto alignmentOpt = $_op.getAlignment();
38+
if (!alignmentOpt)
39+
return ::llvm::MaybeAlign();
40+
return ::llvm::MaybeAlign(static_cast<uint64_t>(*alignmentOpt));
41+
}]
42+
>
43+
];
44+
45+
let extraTraitClassDeclaration = [{
46+
::llvm::MaybeAlign getMaybeAlign() {
47+
auto alignmentOpt = (*static_cast<ConcreteOp *>(this)).getAlignment();
48+
if (!alignmentOpt)
49+
return ::llvm::MaybeAlign();
50+
return ::llvm::MaybeAlign(static_cast<uint64_t>(*alignmentOpt));
51+
}
52+
}];
53+
}
54+
55+
#endif // MLIR_INTERFACES_ALIGNMENTATTRINTERFACE_TD

mlir/include/mlir/Interfaces/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
add_mlir_interface(AlignmentAttrInterface)
12
add_mlir_interface(CallInterfaces)
23
add_mlir_interface(CastInterfaces)
34
add_mlir_interface(ControlFlowInterfaces)

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include "mlir/Transforms/DialectConversion.h"
3131
#include "llvm/ADT/APFloat.h"
3232
#include "llvm/IR/LLVMContext.h"
33+
#include "llvm/Support/Alignment.h"
3334
#include "llvm/Support/Casting.h"
3435

3536
#include <optional>
@@ -248,7 +249,9 @@ class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
248249

249250
// Resolve alignment.
250251
// Explicit alignment takes priority over use-vector-alignment.
251-
unsigned align = loadOrStoreOp.getAlignment().value_or(0);
252+
unsigned align = 0;
253+
if (llvm::MaybeAlign maybeAlign = loadOrStoreOp.getMaybeAlign())
254+
align = maybeAlign->value();
252255
if (!align &&
253256
failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vectorTy,
254257
memRefTy, align, useVectorAlignment)))
@@ -301,7 +304,9 @@ class VectorGatherOpConversion
301304

302305
// Resolve alignment.
303306
// Explicit alignment takes priority over use-vector-alignment.
304-
unsigned align = gather.getAlignment().value_or(0);
307+
unsigned align = 0;
308+
if (llvm::MaybeAlign maybeAlign = gather.getMaybeAlign())
309+
align = maybeAlign->value();
305310
if (!align &&
306311
failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
307312
memRefType, align, useVectorAlignment)))
@@ -358,7 +363,9 @@ class VectorScatterOpConversion
358363

359364
// Resolve alignment.
360365
// Explicit alignment takes priority over use-vector-alignment.
361-
unsigned align = scatter.getAlignment().value_or(0);
366+
unsigned align = 0;
367+
if (llvm::MaybeAlign maybeAlign = scatter.getMaybeAlign())
368+
align = maybeAlign->value();
362369
if (!align &&
363370
failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
364371
memRefType, align, useVectorAlignment)))
@@ -407,7 +414,7 @@ class VectorExpandLoadOpConversion
407414
// From:
408415
// https://llvm.org/docs/LangRef.html#llvm-masked-expandload-intrinsics
409416
// The pointer alignment defaults to 1.
410-
uint64_t alignment = expand.getAlignment().value_or(1);
417+
uint64_t alignment = expand.getMaybeAlign().valueOrOne().value();
411418

412419
rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
413420
expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru(),
@@ -435,7 +442,7 @@ class VectorCompressStoreOpConversion
435442
// From:
436443
// https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics
437444
// The pointer alignment defaults to 1.
438-
uint64_t alignment = compress.getAlignment().value_or(1);
445+
uint64_t alignment = compress.getMaybeAlign().valueOrOne().value();
439446

440447
rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
441448
compress, adaptor.getValueToStore(), ptr, adaptor.getMask(), alignment);

mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ struct VectorMaskedLoadOpConverter final
7575
[&](OpBuilder &builder, Location loc) {
7676
auto loadedValue = memref::LoadOp::create(
7777
builder, loc, base, indices, /*nontemporal=*/false,
78-
llvm::MaybeAlign(maskedLoadOp.getAlignment().value_or(0)));
78+
maskedLoadOp.getMaybeAlign());
7979
auto combinedValue =
8080
vector::InsertOp::create(builder, loc, loadedValue, iValue, i);
8181
scf::YieldOp::create(builder, loc, combinedValue.getResult());
@@ -143,9 +143,8 @@ struct VectorMaskedStoreOpConverter final
143143
auto ifOp = scf::IfOp::create(rewriter, loc, maskBit, /*else=*/false);
144144
rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
145145
auto extractedValue = vector::ExtractOp::create(rewriter, loc, value, i);
146-
memref::StoreOp::create(
147-
rewriter, loc, extractedValue, base, indices, nontemporal,
148-
llvm::MaybeAlign(maskedStoreOp.getAlignment().value_or(0)));
146+
memref::StoreOp::create(rewriter, loc, extractedValue, base, indices,
147+
nontemporal, maskedStoreOp.getMaybeAlign());
149148

150149
rewriter.setInsertionPointAfter(ifOp);
151150
indices.back() =
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
//===- AlignmentAttrInterface.cpp -----------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Interfaces/AlignmentAttrInterface.h"
10+
11+
using namespace mlir;
12+
13+
#include "mlir/Interfaces/AlignmentAttrInterface.cpp.inc"

0 commit comments

Comments
 (0)