Skip to content

Commit 7b467bc

Browse files
amd-eochoalokuhar
andauthored
[mlir][vector] Add alignment attribute to vector operations. (llvm#152507)
Following llvm#144344, llvm#152207, llvm#151690, this PR adds the alignment attribute to the following operations in the vector dialect: * `compressstore` * `expandload` * `vector.scatter` * `vector.gather` --------- Co-authored-by: Jakub Kuderski <[email protected]>
1 parent 374db67 commit 7b467bc

File tree

2 files changed

+163
-20
lines changed

2 files changed

+163
-20
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 87 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1714,7 +1714,6 @@ def Vector_LoadOp : Vector_Op<"load", [
17141714
load operation. It must be a positive power of 2. The operation must access
17151715
memory at an address aligned to this boundary. Violations may lead to
17161716
architecture-specific faults or performance penalties.
1717-
A value of 0 indicates no specific alignment requirement.
17181717
}];
17191718

17201719
let arguments = (ins Arg<AnyMemRef, "the reference to load from",
@@ -1830,7 +1829,6 @@ def Vector_StoreOp : Vector_Op<"store", [
18301829
store operation. It must be a positive power of 2. The operation must access
18311830
memory at an address aligned to this boundary. Violations may lead to
18321831
architecture-specific faults or performance penalties.
1833-
A value of 0 indicates no specific alignment requirement.
18341832
}];
18351833

18361834
let arguments = (ins
@@ -1919,7 +1917,6 @@ def Vector_MaskedLoadOp :
19191917
load operation. It must be a positive power of 2. The operation must access
19201918
memory at an address aligned to this boundary. Violations may lead to
19211919
architecture-specific faults or performance penalties.
1922-
A value of 0 indicates no specific alignment requirement.
19231920
}];
19241921
let extraClassDeclaration = [{
19251922
MemRefType getMemRefType() {
@@ -2012,7 +2009,6 @@ def Vector_MaskedStoreOp :
20122009
store operation. It must be a positive power of 2. The operation must access
20132010
memory at an address aligned to this boundary. Violations may lead to
20142011
architecture-specific faults or performance penalties.
2015-
A value of 0 indicates no specific alignment requirement.
20162012
}];
20172013
let extraClassDeclaration = [{
20182014
MemRefType getMemRefType() {
@@ -2054,7 +2050,9 @@ def Vector_GatherOp :
20542050
Variadic<Index>:$indices,
20552051
VectorOfNonZeroRankOf<[AnyInteger, Index]>:$index_vec,
20562052
VectorOfNonZeroRankOf<[I1]>:$mask,
2057-
AnyVectorOfNonZeroRank:$pass_thru)>,
2053+
AnyVectorOfNonZeroRank:$pass_thru,
2054+
ConfinedAttr<OptionalAttr<I64Attr>,
2055+
[AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment)>,
20582056
Results<(outs AnyVectorOfNonZeroRank:$result)> {
20592057

20602058
let summary = [{
@@ -2096,6 +2094,15 @@ def Vector_GatherOp :
20962094
comes from the pass-through vector regardless of the index, and the index is
20972095
allowed to be out-of-bounds.
20982096

2097+
The gather operation can be used directly where applicable, or can be used
2098+
during progressively lowering to bring other memory operations closer to
2099+
hardware ISA support for a gather.
2100+
2101+
An optional `alignment` attribute allows to specify the byte alignment of the
2102+
gather operation. It must be a positive power of 2. The operation must access
2103+
memory at an address aligned to this boundary. Violations may lead to
2104+
architecture-specific faults or performance penalties.
2105+
20992106
Examples:
21002107

21012108
```mlir
@@ -2124,6 +2131,20 @@ def Vector_GatherOp :
21242131
"`into` type($result)";
21252132
let hasCanonicalizer = 1;
21262133
let hasVerifier = 1;
2134+
2135+
let builders = [
2136+
OpBuilder<(ins "VectorType":$resultType,
2137+
"Value":$base,
2138+
"ValueRange":$indices,
2139+
"Value":$index_vec,
2140+
"Value":$mask,
2141+
"Value":$passthrough,
2142+
CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">:$alignment), [{
2143+
return build($_builder, $_state, resultType, base, indices, index_vec, mask, passthrough,
2144+
alignment.has_value() ? $_builder.getI64IntegerAttr(alignment->value()) :
2145+
nullptr);
2146+
}]>
2147+
];
21272148
}
21282149

21292150
def Vector_ScatterOp :
@@ -2132,7 +2153,9 @@ def Vector_ScatterOp :
21322153
Variadic<Index>:$indices,
21332154
VectorOfNonZeroRankOf<[AnyInteger, Index]>:$index_vec,
21342155
VectorOfNonZeroRankOf<[I1]>:$mask,
2135-
AnyVectorOfNonZeroRank:$valueToStore)> {
2156+
AnyVectorOfNonZeroRank:$valueToStore,
2157+
ConfinedAttr<OptionalAttr<I64Attr>,
2158+
[AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment)> {
21362159

21372160
let summary = [{
21382161
scatters elements from a vector into memory as defined by an index vector
@@ -2166,6 +2189,11 @@ def Vector_ScatterOp :
21662189
correspond to those of the `llvm.masked.scatter`
21672190
[intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-scatter-intrinsics).
21682191

2192+
An optional `alignment` attribute allows to specify the byte alignment of the
2193+
scatter operation. It must be a positive power of 2. The operation must access
2194+
memory at an address aligned to this boundary. Violations may lead to
2195+
architecture-specific faults or performance penalties.
2196+
21692197
Examples:
21702198

21712199
```mlir
@@ -2190,14 +2218,29 @@ def Vector_ScatterOp :
21902218
"type($index_vec) `,` type($mask) `,` type($valueToStore)";
21912219
let hasCanonicalizer = 1;
21922220
let hasVerifier = 1;
2221+
2222+
let builders = [
2223+
OpBuilder<(ins "Value":$base,
2224+
"ValueRange":$indices,
2225+
"Value":$index_vec,
2226+
"Value":$mask,
2227+
"Value":$valueToStore,
2228+
CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">: $alignment), [{
2229+
return build($_builder, $_state, base, indices, index_vec, mask, valueToStore,
2230+
alignment.has_value() ? $_builder.getI64IntegerAttr(alignment->value()) :
2231+
nullptr);
2232+
}]>
2233+
];
21932234
}
21942235

21952236
def Vector_ExpandLoadOp :
21962237
Vector_Op<"expandload">,
21972238
Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
21982239
Variadic<Index>:$indices,
21992240
FixedVectorOfNonZeroRankOf<[I1]>:$mask,
2200-
AnyVectorOfNonZeroRank:$pass_thru)>,
2241+
AnyVectorOfNonZeroRank:$pass_thru,
2242+
ConfinedAttr<OptionalAttr<I64Attr>,
2243+
[AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment)>,
22012244
Results<(outs AnyVectorOfNonZeroRank:$result)> {
22022245

22032246
let summary = "reads elements from memory and spreads them into a vector as defined by a mask";
@@ -2229,6 +2272,11 @@ def Vector_ExpandLoadOp :
22292272
correspond to those of the `llvm.masked.expandload`
22302273
[intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-expandload-intrinsics).
22312274

2275+
An optional `alignment` attribute allows to specify the byte alignment of the
2276+
load operation. It must be a positive power of 2. The operation must access
2277+
memory at an address aligned to this boundary. Violations may lead to
2278+
architecture-specific faults or performance penalties.
2279+
22322280
Note, at the moment this Op is only available for fixed-width vectors.
22332281

22342282
Examples:
@@ -2259,14 +2307,29 @@ def Vector_ExpandLoadOp :
22592307
"type($base) `,` type($mask) `,` type($pass_thru) `into` type($result)";
22602308
let hasCanonicalizer = 1;
22612309
let hasVerifier = 1;
2310+
2311+
let builders = [
2312+
OpBuilder<(ins "VectorType":$resultType,
2313+
"Value":$base,
2314+
"ValueRange":$indices,
2315+
"Value":$mask,
2316+
"Value":$passthrough,
2317+
CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">:$alignment), [{
2318+
return build($_builder, $_state, resultType, base, indices, mask, passthrough,
2319+
alignment.has_value() ? $_builder.getI64IntegerAttr(alignment->value()) :
2320+
nullptr);
2321+
}]>
2322+
];
22622323
}
22632324

22642325
def Vector_CompressStoreOp :
22652326
Vector_Op<"compressstore">,
22662327
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
22672328
Variadic<Index>:$indices,
22682329
FixedVectorOfNonZeroRankOf<[I1]>:$mask,
2269-
AnyVectorOfNonZeroRank:$valueToStore)> {
2330+
AnyVectorOfNonZeroRank:$valueToStore,
2331+
ConfinedAttr<OptionalAttr<I64Attr>,
2332+
[AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment)> {
22702333

22712334
let summary = "writes elements selectively from a vector as defined by a mask";
22722335

@@ -2297,6 +2360,11 @@ def Vector_CompressStoreOp :
22972360
correspond to those of the `llvm.masked.compressstore`
22982361
[intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics).
22992362

2363+
An optional `alignment` attribute allows to specify the byte alignment of the
2364+
store operation. It must be a positive power of 2. The operation must access
2365+
memory at an address aligned to this boundary. Violations may lead to
2366+
architecture-specific faults or performance penalties.
2367+
23002368
Note, at the moment this Op is only available for fixed-width vectors.
23012369

23022370
Examples:
@@ -2325,6 +2393,17 @@ def Vector_CompressStoreOp :
23252393
"type($base) `,` type($mask) `,` type($valueToStore)";
23262394
let hasCanonicalizer = 1;
23272395
let hasVerifier = 1;
2396+
let builders = [
2397+
OpBuilder<(ins "Value":$base,
2398+
"ValueRange":$indices,
2399+
"Value":$mask,
2400+
"Value":$valueToStore,
2401+
CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">:$alignment), [{
2402+
return build($_builder, $_state, base, indices, valueToStore, mask,
2403+
alignment.has_value() ? $_builder.getI64IntegerAttr(alignment->value()) :
2404+
nullptr);
2405+
}]>
2406+
];
23282407
}
23292408

23302409
def Vector_ShapeCastOp :

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 76 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1309,15 +1309,15 @@ func.func @store_memref_index_mismatch(%base : memref<?xf32>, %value : vector<16
13091309
// vector.maskedload
13101310
//===----------------------------------------------------------------------===//
13111311

1312-
func.func @maskedload_negative_alignment(%base: memref<4xi32>, %mask: vector<32xi1>, %pass: vector<1xi32>, %index: index) {
1312+
func.func @maskedload_nonpositive_alignment(%base: memref<4xi32>, %mask: vector<32xi1>, %pass: vector<1xi32>, %index: index) {
13131313
// expected-error@below {{'vector.maskedload' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
1314-
%val = vector.maskedload %base[%index], %mask, %pass { alignment = -1 } : memref<4xi32>, vector<32xi1>, vector<1xi32> into vector<1xi32>
1314+
%val = vector.maskedload %base[%index], %mask, %pass { alignment = 0 } : memref<4xi32>, vector<32xi1>, vector<1xi32> into vector<1xi32>
13151315
return
13161316
}
13171317

13181318
// -----
13191319

1320-
func.func @maskedload_nonpoweroftwo_alignment(%base: memref<4xi32>, %mask: vector<32xi1>, %pass: vector<1xi32>, %index: index) {
1320+
func.func @maskedload_non_power_of_2_alignment(%base: memref<4xi32>, %mask: vector<32xi1>, %pass: vector<1xi32>, %index: index) {
13211321
// expected-error@below {{'vector.maskedload' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
13221322
%val = vector.maskedload %base[%index], %mask, %pass { alignment = 3 } : memref<4xi32>, vector<32xi1>, vector<1xi32> into vector<1xi32>
13231323
return
@@ -1360,15 +1360,15 @@ func.func @maskedload_memref_mismatch(%base: memref<?xf32>, %mask: vector<16xi1>
13601360
// vector.maskedstore
13611361
//===----------------------------------------------------------------------===//
13621362

1363-
func.func @maskedstore_negative_alignment(%base: memref<4xi32>, %mask: vector<32xi1>, %value: vector<1xi32>, %index: index) {
1363+
func.func @maskedstore_nonpositive_alignment(%base: memref<4xi32>, %mask: vector<32xi1>, %value: vector<1xi32>, %index: index) {
13641364
// expected-error@below {{'vector.maskedstore' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
1365-
vector.maskedstore %base[%index], %mask, %value { alignment = -1 } : memref<4xi32>, vector<32xi1>, vector<1xi32> into vector<1xi32>
1365+
vector.maskedstore %base[%index], %mask, %value { alignment = 0 } : memref<4xi32>, vector<32xi1>, vector<1xi32> into vector<1xi32>
13661366
return
13671367
}
13681368

13691369
// -----
13701370

1371-
func.func @maskedstore_nonpoweroftwo_alignment(%base: memref<4xi32>, %mask: vector<32xi1>, %value: vector<1xi32>, %index: index) {
1371+
func.func @maskedstore_non_power_of_2_alignment(%base: memref<4xi32>, %mask: vector<32xi1>, %value: vector<1xi32>, %index: index) {
13721372
// expected-error@below {{'vector.maskedstore' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
13731373
vector.maskedstore %base[%index], %mask, %value { alignment = 3 } : memref<4xi32>, vector<32xi1>, vector<1xi32> into vector<1xi32>
13741374
return
@@ -1470,6 +1470,24 @@ func.func @gather_pass_thru_type_mismatch(%base: memref<?xf32>, %indices: vector
14701470

14711471
// -----
14721472

1473+
func.func @gather_nonpositive_alignment(%base: memref<16xf32>, %indices: vector<16xi32>,
1474+
%mask: vector<16xi1>, %pass_thru: vector<16xf32>, %c0 : index) {
1475+
// expected-error@+2 {{'vector.gather' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
1476+
%0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
1477+
{ alignment = 0 } : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
1478+
}
1479+
1480+
// -----
1481+
1482+
func.func @gather_non_power_of_two_alignment(%base: memref<16xf32>, %indices: vector<16xi32>,
1483+
%mask: vector<16xi1>, %pass_thru: vector<16xf32>, %c0 : index) {
1484+
// expected-error@+2 {{'vector.gather' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
1485+
%0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
1486+
{ alignment = 3 } : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
1487+
}
1488+
1489+
// -----
1490+
14731491
func.func @scatter_to_vector(%base: vector<16xf32>, %indices: vector<16xi32>,
14741492
%mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
14751493
%c0 = arith.constant 0 : index
@@ -1531,6 +1549,24 @@ func.func @scatter_dim_mask_mismatch(%base: memref<?xf32>, %indices: vector<16xi
15311549

15321550
// -----
15331551

1552+
func.func @scatter_nonpositive_alignment(%base: memref<?xf32>, %indices: vector<16xi32>,
1553+
%mask: vector<16xi1>, %value: vector<16xf32>, %c0: index) {
1554+
// expected-error@+1 {{'vector.scatter' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
1555+
vector.scatter %base[%c0][%indices], %mask, %value { alignment = 0 }
1556+
: memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
1557+
}
1558+
1559+
// -----
1560+
1561+
func.func @scatter_non_power_of_2_alignment(%base: memref<?xf32>, %indices: vector<16xi32>,
1562+
%mask: vector<16xi1>, %value: vector<16xf32>, %c0: index) {
1563+
// expected-error@+1 {{'vector.scatter' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
1564+
vector.scatter %base[%c0][%indices], %mask, %value { alignment = 3 }
1565+
: memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
1566+
}
1567+
1568+
// -----
1569+
15341570
func.func @expand_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
15351571
%c0 = arith.constant 0 : index
15361572
// expected-error@+1 {{'vector.expandload' op base and result element type should match}}
@@ -1571,6 +1607,20 @@ func.func @expand_memref_mismatch(%base: memref<?x?xf32>, %mask: vector<16xi1>,
15711607

15721608
// -----
15731609

1610+
func.func @expand_nonpositive_alignment(%base: memref<?xf32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>, %c0: index) {
1611+
// expected-error@+1 {{'vector.expandload' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
1612+
%0 = vector.expandload %base[%c0], %mask, %pass_thru { alignment = 0 } : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
1613+
}
1614+
1615+
// -----
1616+
1617+
func.func @expand_non_power_of_2_alignment(%base: memref<?xf32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>, %c0: index) {
1618+
// expected-error@+1 {{'vector.expandload' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
1619+
%0 = vector.expandload %base[%c0], %mask, %pass_thru { alignment = 3 } : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
1620+
}
1621+
1622+
// -----
1623+
15741624
func.func @compress_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %value: vector<16xf32>) {
15751625
%c0 = arith.constant 0 : index
15761626
// expected-error@+1 {{'vector.compressstore' op base and valueToStore element type should match}}
@@ -1603,6 +1653,20 @@ func.func @compress_memref_mismatch(%base: memref<?x?xf32>, %mask: vector<16xi1>
16031653

16041654
// -----
16051655

1656+
func.func @compress_nonpositive_alignment(%base: memref<?xf32>, %mask: vector<16xi1>, %value: vector<16xf32>, %c0: index) {
1657+
// expected-error @below {{'vector.compressstore' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
1658+
vector.compressstore %base[%c0], %mask, %value { alignment = 0 } : memref<?xf32>, vector<16xi1>, vector<16xf32>
1659+
}
1660+
1661+
// -----
1662+
1663+
func.func @compress_non_power_of_2_alignment(%base: memref<?xf32>, %mask: vector<16xi1>, %value: vector<16xf32>, %c0: index) {
1664+
// expected-error @below {{'vector.compressstore' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
1665+
vector.compressstore %base[%c0], %mask, %value { alignment = 3 } : memref<?xf32>, vector<16xi1>, vector<16xf32>
1666+
}
1667+
1668+
// -----
1669+
16061670
func.func @scan_reduction_dim_constraint(%arg0: vector<2x3xi32>, %arg1: vector<3xi32>) -> vector<3xi32> {
16071671
// expected-error@+1 {{'vector.scan' op reduction dimension 5 has to be less than 2}}
16081672
%0:2 = vector.scan <add>, %arg0, %arg1 {inclusive = true, reduction_dim = 5} :
@@ -1952,15 +2016,15 @@ func.func @vector_load(%src : memref<?xi8>) {
19522016

19532017
// -----
19542018

1955-
func.func @invalid_load_alignment(%memref: memref<4xi32>, %c0: index) {
2019+
func.func @load_nonpositive_alignment(%memref: memref<4xi32>, %c0: index) {
19562020
// expected-error @below {{'vector.load' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
1957-
%val = vector.load %memref[%c0] { alignment = -1 } : memref<4xi32>, vector<4xi32>
2021+
%val = vector.load %memref[%c0] { alignment = 0 } : memref<4xi32>, vector<4xi32>
19582022
return
19592023
}
19602024

19612025
// -----
19622026

1963-
func.func @invalid_load_alignment(%memref: memref<4xi32>, %c0: index) {
2027+
func.func @load_non_pow_of_2_alignment(%memref: memref<4xi32>, %c0: index) {
19642028
// expected-error @below {{'vector.load' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
19652029
%val = vector.load %memref[%c0] { alignment = 3 } : memref<4xi32>, vector<4xi32>
19662030
return
@@ -1981,15 +2045,15 @@ func.func @vector_store(%dest : memref<?xi8>, %vec : vector<16x16xi8>) {
19812045

19822046
// -----
19832047

1984-
func.func @invalid_store_alignment(%memref: memref<4xi32>, %val: vector<4xi32>, %c0: index) {
2048+
func.func @store_nonpositive_alignment(%memref: memref<4xi32>, %val: vector<4xi32>, %c0: index) {
19852049
// expected-error @below {{'vector.store' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
1986-
vector.store %val, %memref[%c0] { alignment = -1 } : memref<4xi32>, vector<4xi32>
2050+
vector.store %val, %memref[%c0] { alignment = 0 } : memref<4xi32>, vector<4xi32>
19872051
return
19882052
}
19892053

19902054
// -----
19912055

1992-
func.func @invalid_store_alignment(%memref: memref<4xi32>, %val: vector<4xi32>, %c0: index) {
2056+
func.func @store_non_pow_of_2_alignment(%memref: memref<4xi32>, %val: vector<4xi32>, %c0: index) {
19932057
// expected-error @below {{'vector.store' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
19942058
vector.store %val, %memref[%c0] { alignment = 3 } : memref<4xi32>, vector<4xi32>
19952059
return

0 commit comments

Comments
 (0)