Skip to content

Commit 16dbbb4

Browse files
committed
[mlir] Add AlignmentAttrOpInterface for unified alignment attribute handling
Introduce a common interface for operations with alignment attributes across MemRef, Vector, and SPIRV dialects. The interface exposes getMaybeAlign() to retrieve alignment as llvm::MaybeAlign.
1 parent 8cba910 commit 16dbbb4

File tree

13 files changed

+168
-20
lines changed

13 files changed

+168
-20
lines changed

mlir/include/mlir/Dialect/MemRef/IR/MemRef.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "mlir/Dialect/Arith/IR/Arith.h"
1414
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
1515
#include "mlir/IR/Dialect.h"
16+
#include "mlir/Interfaces/AlignmentAttrInterface.h"
1617
#include "mlir/Interfaces/CallInterfaces.h"
1718
#include "mlir/Interfaces/CastInterfaces.h"
1819
#include "mlir/Interfaces/ControlFlowInterfaces.h"

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
include "mlir/Dialect/Arith/IR/ArithBase.td"
1313
include "mlir/Dialect/MemRef/IR/MemRefBase.td"
14+
include "mlir/Interfaces/AlignmentAttrInterface.td"
1415
include "mlir/Interfaces/CastInterfaces.td"
1516
include "mlir/Interfaces/ControlFlowInterfaces.td"
1617
include "mlir/Interfaces/InferIntRangeInterface.td"
@@ -64,15 +65,15 @@ class AllocLikeOp<string mnemonic,
6465
list<Trait> traits = []> :
6566
MemRef_Op<mnemonic,
6667
!listconcat([
67-
AttrSizedOperandSegments
68+
AttrSizedOperandSegments,
69+
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>
6870
], traits)> {
6971

7072
let arguments = (ins Variadic<Index>:$dynamicSizes,
7173
// The symbolic operands (the ones in square brackets)
7274
// bind to the symbols of the memref's layout map.
7375
Variadic<Index>:$symbolOperands,
74-
ConfinedAttr<OptionalAttr<I64Attr>,
75-
[IntMinValue<0>]>:$alignment);
76+
OptionalAttr<IntValidAlignment<I64Attr>>:$alignment);
7677
let results = (outs Res<AnyMemRef, "",
7778
[MemAlloc<resource, 0, FullEffect>]>:$memref);
7879

@@ -232,7 +233,8 @@ def MemRef_AllocOp : AllocLikeOp<"alloc", DefaultResource, [
232233
//===----------------------------------------------------------------------===//
233234

234235

235-
def MemRef_ReallocOp : MemRef_Op<"realloc"> {
236+
def MemRef_ReallocOp : MemRef_Op<"realloc",
237+
[DeclareOpInterfaceMethods<AlignmentAttrOpInterface>]> {
236238
let summary = "memory reallocation operation";
237239
let description = [{
238240
The `realloc` operation changes the size of a memory region. The memory
@@ -298,8 +300,7 @@ def MemRef_ReallocOp : MemRef_Op<"realloc"> {
298300
let arguments = (ins Arg<MemRefRankOf<[AnyType], [1]>, "",
299301
[MemFreeAt<0, FullEffect>]>:$source,
300302
Optional<Index>:$dynamicResultSize,
301-
ConfinedAttr<OptionalAttr<I64Attr>,
302-
[IntMinValue<0>]>:$alignment);
303+
OptionalAttr<IntValidAlignment<I64Attr>>:$alignment);
303304

304305
let results = (outs Res<MemRefRankOf<[AnyType], [1]>, "",
305306
[MemAlloc<DefaultResource, 1,
@@ -1122,7 +1123,8 @@ def MemRef_GetGlobalOp : MemRef_Op<"get_global",
11221123
// GlobalOp
11231124
//===----------------------------------------------------------------------===//
11241125

1125-
def MemRef_GlobalOp : MemRef_Op<"global", [Symbol]> {
1126+
def MemRef_GlobalOp : MemRef_Op<"global", [Symbol,
1127+
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>]> {
11261128
let summary = "declare or define a global memref variable";
11271129
let description = [{
11281130
The `memref.global` operation declares or defines a named global memref
@@ -1197,6 +1199,7 @@ def LoadOp : MemRef_Op<"load",
11971199
"memref", "result",
11981200
"::llvm::cast<MemRefType>($_self).getElementType()">,
11991201
MemRefsNormalizable,
1202+
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>,
12001203
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
12011204
DeclareOpInterfaceMethods<PromotableMemOpInterface>,
12021205
DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>]> {
@@ -1964,6 +1967,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
19641967
"memref", "value",
19651968
"::llvm::cast<MemRefType>($_self).getElementType()">,
19661969
MemRefsNormalizable,
1970+
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>,
19671971
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
19681972
DeclareOpInterfaceMethods<PromotableMemOpInterface>,
19691973
DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>]> {

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
#ifndef MLIR_DIALECT_SPIRV_IR_COOPERATIVE_MATRIX_OPS
1717
#define MLIR_DIALECT_SPIRV_IR_COOPERATIVE_MATRIX_OPS
1818

19+
include "mlir/Interfaces/AlignmentAttrInterface.td"
20+
1921
//===----------------------------------------------------------------------===//
2022
// SPV_KHR_cooperative_matrix extension ops.
2123
//===----------------------------------------------------------------------===//
@@ -62,7 +64,7 @@ def SPIRV_KHRCooperativeMatrixLengthOp :
6264

6365
// -----
6466

65-
def SPIRV_KHRCooperativeMatrixLoadOp : SPIRV_KhrVendorOp<"CooperativeMatrixLoad", []> {
67+
def SPIRV_KHRCooperativeMatrixLoadOp : SPIRV_KhrVendorOp<"CooperativeMatrixLoad", [DeclareOpInterfaceMethods<AlignmentAttrOpInterface>]> {
6668
let summary = "Loads a cooperative matrix through a pointer";
6769

6870
let description = [{
@@ -148,7 +150,7 @@ def SPIRV_KHRCooperativeMatrixLoadOp : SPIRV_KhrVendorOp<"CooperativeMatrixLoad"
148150

149151
// -----
150152

151-
def SPIRV_KHRCooperativeMatrixStoreOp : SPIRV_KhrVendorOp<"CooperativeMatrixStore", []> {
153+
def SPIRV_KHRCooperativeMatrixStoreOp : SPIRV_KhrVendorOp<"CooperativeMatrixStore", [DeclareOpInterfaceMethods<AlignmentAttrOpInterface>]> {
152154
let summary = "Stores a cooperative matrix through a pointer";
153155

154156
let description = [{

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
#define MLIR_DIALECT_SPIRV_IR_MEMORY_OPS
1616

1717
include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"
18+
include "mlir/Interfaces/AlignmentAttrInterface.td"
19+
1820

1921
// -----
2022

@@ -79,7 +81,7 @@ def SPIRV_AccessChainOp : SPIRV_Op<"AccessChain", [Pure]> {
7981

8082
// -----
8183

82-
def SPIRV_CopyMemoryOp : SPIRV_Op<"CopyMemory", []> {
84+
def SPIRV_CopyMemoryOp : SPIRV_Op<"CopyMemory", [DeclareOpInterfaceMethods<AlignmentAttrOpInterface>]> {
8385
let summary = [{
8486
Copy from the memory pointed to by Source to the memory pointed to by
8587
Target. Both operands must be non-void pointers and having the same <id>
@@ -182,7 +184,7 @@ def SPIRV_InBoundsPtrAccessChainOp : SPIRV_Op<"InBoundsPtrAccessChain", [Pure]>
182184

183185
// -----
184186

185-
def SPIRV_LoadOp : SPIRV_Op<"Load", []> {
187+
def SPIRV_LoadOp : SPIRV_Op<"Load", [DeclareOpInterfaceMethods<AlignmentAttrOpInterface>]> {
186188
let summary = "Load through a pointer.";
187189

188190
let description = [{
@@ -310,7 +312,7 @@ def SPIRV_PtrAccessChainOp : SPIRV_Op<"PtrAccessChain", [Pure]> {
310312

311313
// -----
312314

313-
def SPIRV_StoreOp : SPIRV_Op<"Store", []> {
315+
def SPIRV_StoreOp : SPIRV_Op<"Store", [DeclareOpInterfaceMethods<AlignmentAttrOpInterface>]> {
314316
let summary = "Store through a pointer.";
315317

316318
let description = [{

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "mlir/Dialect/SPIRV/Interfaces/SPIRVImageInterfaces.h"
2121
#include "mlir/IR/BuiltinOps.h"
2222
#include "mlir/IR/OpImplementation.h"
23+
#include "mlir/Interfaces/AlignmentAttrInterface.h"
2324
#include "mlir/Interfaces/CallInterfaces.h"
2425
#include "mlir/Interfaces/ControlFlowInterfaces.h"
2526
#include "mlir/Interfaces/FunctionInterfaces.h"

mlir/include/mlir/Dialect/Vector/IR/VectorOps.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "mlir/IR/Dialect.h"
2424
#include "mlir/IR/OpDefinition.h"
2525
#include "mlir/IR/PatternMatch.h"
26+
#include "mlir/Interfaces/AlignmentAttrInterface.h"
2627
#include "mlir/Interfaces/ControlFlowInterfaces.h"
2728
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
2829
#include "mlir/Interfaces/IndexingMapOpInterface.h"

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.td"
1919
include "mlir/Dialect/Vector/Interfaces/MaskingOpInterface.td"
2020
include "mlir/Dialect/Vector/IR/Vector.td"
2121
include "mlir/Dialect/Vector/IR/VectorAttributes.td"
22+
include "mlir/Interfaces/AlignmentAttrInterface.td"
2223
include "mlir/Interfaces/ControlFlowInterfaces.td"
2324
include "mlir/Interfaces/DestinationStyleOpInterface.td"
2425
include "mlir/Interfaces/IndexingMapOpInterface.td"
@@ -1652,7 +1653,8 @@ def Vector_TransferWriteOp :
16521653

16531654
def Vector_LoadOp : Vector_Op<"load", [
16541655
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
1655-
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>
1656+
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
1657+
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>
16561658
]> {
16571659
let summary = "reads an n-D slice of memory into an n-D vector";
16581660
let description = [{
@@ -1769,7 +1771,8 @@ def Vector_LoadOp : Vector_Op<"load", [
17691771

17701772
def Vector_StoreOp : Vector_Op<"store", [
17711773
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
1772-
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>
1774+
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
1775+
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>
17731776
]> {
17741777
let summary = "writes an n-D vector to an n-D slice of memory";
17751778
let description = [{
@@ -1874,7 +1877,10 @@ def Vector_StoreOp : Vector_Op<"store", [
18741877
}
18751878

18761879
def Vector_MaskedLoadOp :
1877-
Vector_Op<"maskedload", [DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>]>,
1880+
Vector_Op<"maskedload", [
1881+
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
1882+
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>
1883+
]>,
18781884
Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
18791885
Variadic<Index>:$indices,
18801886
VectorOfNonZeroRankOf<[I1]>:$mask,
@@ -1966,7 +1972,10 @@ def Vector_MaskedLoadOp :
19661972
}
19671973

19681974
def Vector_MaskedStoreOp :
1969-
Vector_Op<"maskedstore", [DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>]>,
1975+
Vector_Op<"maskedstore", [
1976+
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
1977+
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>
1978+
]>,
19701979
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
19711980
Variadic<Index>:$indices,
19721981
VectorOfNonZeroRankOf<[I1]>:$mask,
@@ -2047,7 +2056,8 @@ def Vector_GatherOp :
20472056
Vector_Op<"gather", [
20482057
DeclareOpInterfaceMethods<MaskableOpInterface>,
20492058
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
2050-
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
2059+
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
2060+
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>
20512061
]>,
20522062
Arguments<(ins Arg<TensorOrMemRef<[AnyType]>, "", [MemRead]>:$base,
20532063
Variadic<Index>:$offsets,
@@ -2150,7 +2160,10 @@ def Vector_GatherOp :
21502160
}
21512161

21522162
def Vector_ScatterOp :
2153-
Vector_Op<"scatter", [DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>]>,
2163+
Vector_Op<"scatter", [
2164+
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
2165+
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>
2166+
]>,
21542167
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
21552168
Variadic<Index>:$offsets,
21562169
VectorOfNonZeroRankOf<[AnyInteger, Index]>:$indices,
@@ -2235,7 +2248,10 @@ def Vector_ScatterOp :
22352248
}
22362249

22372250
def Vector_ExpandLoadOp :
2238-
Vector_Op<"expandload", [DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>]>,
2251+
Vector_Op<"expandload", [
2252+
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
2253+
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>
2254+
]>,
22392255
Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
22402256
Variadic<Index>:$indices,
22412257
FixedVectorOfNonZeroRankOf<[I1]>:$mask,
@@ -2323,7 +2339,10 @@ def Vector_ExpandLoadOp :
23232339
}
23242340

23252341
def Vector_CompressStoreOp :
2326-
Vector_Op<"compressstore", [DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>]>,
2342+
Vector_Op<"compressstore", [
2343+
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
2344+
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>
2345+
]>,
23272346
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
23282347
Variadic<Index>:$indices,
23292348
FixedVectorOfNonZeroRankOf<[I1]>:$mask,
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
//===- AlignmentAttrInterface.h - Alignment attribute interface -*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_INTERFACES_ALIGNMENTATTRINTERFACE_H
10+
#define MLIR_INTERFACES_ALIGNMENTATTRINTERFACE_H
11+
12+
#include "mlir/IR/OpDefinition.h"
13+
#include "llvm/Support/Alignment.h"
14+
15+
namespace mlir {
16+
class MLIRContext;
17+
} // namespace mlir
18+
19+
#include "mlir/Interfaces/AlignmentAttrInterface.h.inc"
20+
21+
#endif // MLIR_INTERFACES_ALIGNMENTATTRINTERFACE_H
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
//===- AlignmentAttrInterface.td - Alignment attribute interface -*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file defines an interface for operations that expose an optional
10+
// alignment attribute.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#ifndef MLIR_INTERFACES_ALIGNMENTATTRINTERFACE_TD
15+
#define MLIR_INTERFACES_ALIGNMENTATTRINTERFACE_TD
16+
17+
include "mlir/IR/OpBase.td"
18+
19+
def AlignmentAttrOpInterface : OpInterface<"AlignmentAttrOpInterface"> {
20+
let description = [{
21+
An interface for operations that carry an optional alignment attribute and
22+
want to expose it as an `llvm::MaybeAlign` helper.
23+
}];
24+
25+
let cppNamespace = "::mlir";
26+
27+
let methods = [
28+
InterfaceMethod<[{
29+
Returns the alignment encoded on the operation as an `llvm::MaybeAlign`.
30+
Operations providing a differently named accessor can override the
31+
default implementation.
32+
}],
33+
"::llvm::MaybeAlign",
34+
"getMaybeAlign",
35+
(ins),
36+
[{
37+
// Defensive: trait implementations are expected to validate power-of-two
38+
// alignments, but we still guard against accidental misuse.
39+
auto alignmentOpt = $_op.getAlignment();
40+
if (!alignmentOpt || *alignmentOpt <= 0)
41+
return ::llvm::MaybeAlign();
42+
uint64_t value = static_cast<uint64_t>(*alignmentOpt);
43+
if (!::llvm::isPowerOf2_64(value))
44+
return ::llvm::MaybeAlign();
45+
return ::llvm::MaybeAlign(value);
46+
}]
47+
>
48+
];
49+
50+
let extraTraitClassDeclaration = [{
51+
::llvm::MaybeAlign getMaybeAlign() {
52+
// Defensive: trait implementations are expected to validate power-of-two
53+
// alignments, but we still guard against accidental misuse.
54+
auto alignmentOpt = (*static_cast<ConcreteOp *>(this)).getAlignment();
55+
if (!alignmentOpt || *alignmentOpt <= 0)
56+
return ::llvm::MaybeAlign();
57+
uint64_t value = static_cast<uint64_t>(*alignmentOpt);
58+
if (!::llvm::isPowerOf2_64(value))
59+
return ::llvm::MaybeAlign();
60+
return ::llvm::MaybeAlign(value);
61+
}
62+
}];
63+
}
64+
65+
#endif // MLIR_INTERFACES_ALIGNMENTATTRINTERFACE_TD

mlir/include/mlir/Interfaces/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
add_mlir_interface(AlignmentAttrInterface)
12
add_mlir_interface(CallInterfaces)
23
add_mlir_interface(CastInterfaces)
34
add_mlir_interface(ControlFlowInterfaces)

0 commit comments

Comments
 (0)