Skip to content

Commit a0cc39a

Browse files
committed
[mlir][vector] Add alignment to vector operations.
This is a squash of PR llvm#152507
1 parent 4d4966d commit a0cc39a

File tree

2 files changed

+157
-10
lines changed

2 files changed

+157
-10
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 87 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2054,7 +2054,9 @@ def Vector_GatherOp :
20542054
Variadic<Index>:$indices,
20552055
VectorOfNonZeroRankOf<[AnyInteger, Index]>:$index_vec,
20562056
VectorOfNonZeroRankOf<[I1]>:$mask,
2057-
AnyVectorOfNonZeroRank:$pass_thru)>,
2057+
AnyVectorOfNonZeroRank:$pass_thru,
2058+
ConfinedAttr<OptionalAttr<I64Attr>,
2059+
[AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment)>,
20582060
Results<(outs AnyVectorOfNonZeroRank:$result)> {
20592061

20602062
let summary = [{
@@ -2085,6 +2087,12 @@ def Vector_GatherOp :
20852087
during progressively lowering to bring other memory operations closer to
20862088
hardware ISA support for a gather.
20872089

2090+
An optional `alignment` attribute allows to specify the byte alignment of the
2091+
scatter operation. It must be a positive power of 2. The operation must access
2092+
memory at an address aligned to this boundary. Violations may lead to
2093+
architecture-specific faults or performance penalties.
2094+
A value of 0 indicates no specific alignment requirement.
2095+
20882096
Examples:
20892097

20902098
```mlir
@@ -2111,6 +2119,20 @@ def Vector_GatherOp :
21112119
"`into` type($result)";
21122120
let hasCanonicalizer = 1;
21132121
let hasVerifier = 1;
2122+
2123+
let builders = [
2124+
OpBuilder<(ins "VectorType":$resultType,
2125+
"Value":$base,
2126+
"ValueRange":$indices,
2127+
"Value":$index_vec,
2128+
"Value":$mask,
2129+
"Value":$passthrough,
2130+
CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">:$alignment), [{
2131+
return build($_builder, $_state, resultType, base, indices, index_vec, mask, passthrough,
2132+
alignment.has_value() ? $_builder.getI64IntegerAttr(alignment->value()) :
2133+
nullptr);
2134+
}]>
2135+
];
21142136
}
21152137

21162138
def Vector_ScatterOp :
@@ -2119,7 +2141,9 @@ def Vector_ScatterOp :
21192141
Variadic<Index>:$indices,
21202142
VectorOfNonZeroRankOf<[AnyInteger, Index]>:$index_vec,
21212143
VectorOfNonZeroRankOf<[I1]>:$mask,
2122-
AnyVectorOfNonZeroRank:$valueToStore)> {
2144+
AnyVectorOfNonZeroRank:$valueToStore,
2145+
ConfinedAttr<OptionalAttr<I64Attr>,
2146+
[AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment)> {
21232147

21242148
let summary = [{
21252149
scatters elements from a vector into memory as defined by an index vector
@@ -2153,6 +2177,12 @@ def Vector_ScatterOp :
21532177
correspond to those of the `llvm.masked.scatter`
21542178
[intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-scatter-intrinsics).
21552179

2180+
An optional `alignment` attribute allows to specify the byte alignment of the
2181+
scatter operation. It must be a positive power of 2. The operation must access
2182+
memory at an address aligned to this boundary. Violations may lead to
2183+
architecture-specific faults or performance penalties.
2184+
A value of 0 indicates no specific alignment requirement.
2185+
21562186
Examples:
21572187

21582188
```mlir
@@ -2177,14 +2207,29 @@ def Vector_ScatterOp :
21772207
"type($index_vec) `,` type($mask) `,` type($valueToStore)";
21782208
let hasCanonicalizer = 1;
21792209
let hasVerifier = 1;
2210+
2211+
let builders = [
2212+
OpBuilder<(ins "Value":$base,
2213+
"ValueRange":$indices,
2214+
"Value":$index_vec,
2215+
"Value":$mask,
2216+
"Value":$valueToStore,
2217+
CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">: $alignment), [{
2218+
return build($_builder, $_state, base, indices, index_vec, mask, valueToStore,
2219+
alignment.has_value() ? $_builder.getI64IntegerAttr(alignment->value()) :
2220+
nullptr);
2221+
}]>
2222+
];
21802223
}
21812224

21822225
def Vector_ExpandLoadOp :
21832226
Vector_Op<"expandload">,
21842227
Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
21852228
Variadic<Index>:$indices,
21862229
FixedVectorOfNonZeroRankOf<[I1]>:$mask,
2187-
AnyVectorOfNonZeroRank:$pass_thru)>,
2230+
AnyVectorOfNonZeroRank:$pass_thru,
2231+
ConfinedAttr<OptionalAttr<I64Attr>,
2232+
[AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment)>,
21882233
Results<(outs AnyVectorOfNonZeroRank:$result)> {
21892234

21902235
let summary = "reads elements from memory and spreads them into a vector as defined by a mask";
@@ -2216,6 +2261,12 @@ def Vector_ExpandLoadOp :
22162261
correspond to those of the `llvm.masked.expandload`
22172262
[intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-expandload-intrinsics).
22182263

2264+
An optional `alignment` attribute allows to specify the byte alignment of the
2265+
load operation. It must be a positive power of 2. The operation must access
2266+
memory at an address aligned to this boundary. Violations may lead to
2267+
architecture-specific faults or performance penalties.
2268+
A value of 0 indicates no specific alignment requirement.
2269+
22192270
Note, at the moment this Op is only available for fixed-width vectors.
22202271

22212272
Examples:
@@ -2246,14 +2297,29 @@ def Vector_ExpandLoadOp :
22462297
"type($base) `,` type($mask) `,` type($pass_thru) `into` type($result)";
22472298
let hasCanonicalizer = 1;
22482299
let hasVerifier = 1;
2300+
2301+
let builders = [
2302+
OpBuilder<(ins "VectorType":$resultType,
2303+
"Value":$base,
2304+
"ValueRange":$indices,
2305+
"Value":$mask,
2306+
"Value":$passthrough,
2307+
CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">:$alignment), [{
2308+
return build($_builder, $_state, resultType, base, indices, mask, passthrough,
2309+
alignment.has_value() ? $_builder.getI64IntegerAttr(alignment->value()) :
2310+
nullptr);
2311+
}]>
2312+
];
22492313
}
22502314

22512315
def Vector_CompressStoreOp :
22522316
Vector_Op<"compressstore">,
22532317
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
22542318
Variadic<Index>:$indices,
22552319
FixedVectorOfNonZeroRankOf<[I1]>:$mask,
2256-
AnyVectorOfNonZeroRank:$valueToStore)> {
2320+
AnyVectorOfNonZeroRank:$valueToStore,
2321+
ConfinedAttr<OptionalAttr<I64Attr>,
2322+
[AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment)> {
22572323

22582324
let summary = "writes elements selectively from a vector as defined by a mask";
22592325

@@ -2284,6 +2350,12 @@ def Vector_CompressStoreOp :
22842350
correspond to those of the `llvm.masked.compressstore`
22852351
[intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics).
22862352

2353+
An optional `alignment` attribute allows to specify the byte alignment of the
2354+
store operation. It must be a positive power of 2. The operation must access
2355+
memory at an address aligned to this boundary. Violations may lead to
2356+
architecture-specific faults or performance penalties.
2357+
A value of 0 indicates no specific alignment requirement.
2358+
22872359
Note, at the moment this Op is only available for fixed-width vectors.
22882360

22892361
Examples:
@@ -2312,6 +2384,17 @@ def Vector_CompressStoreOp :
23122384
"type($base) `,` type($mask) `,` type($valueToStore)";
23132385
let hasCanonicalizer = 1;
23142386
let hasVerifier = 1;
2387+
let builders = [
2388+
OpBuilder<(ins "Value":$base,
2389+
"ValueRange":$indices,
2390+
"Value":$mask,
2391+
"Value":$valueToStore,
2392+
CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">:$alignment), [{
2393+
return build($_builder, $_state, base, indices, valueToStore, mask,
2394+
alignment.has_value() ? $_builder.getI64IntegerAttr(alignment->value()) :
2395+
nullptr);
2396+
}]>
2397+
];
23152398
}
23162399

23172400
def Vector_ShapeCastOp :

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 70 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1317,7 +1317,7 @@ func.func @maskedload_negative_alignment(%base: memref<4xi32>, %mask: vector<32x
13171317

13181318
// -----
13191319

1320-
func.func @maskedload_nonpoweroftwo_alignment(%base: memref<4xi32>, %mask: vector<32xi1>, %pass: vector<1xi32>, %index: index) {
1320+
func.func @maskedload_non_power_of_2_alignment(%base: memref<4xi32>, %mask: vector<32xi1>, %pass: vector<1xi32>, %index: index) {
13211321
// expected-error@below {{'vector.maskedload' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
13221322
%val = vector.maskedload %base[%index], %mask, %pass { alignment = 3 } : memref<4xi32>, vector<32xi1>, vector<1xi32> into vector<1xi32>
13231323
return
@@ -1368,7 +1368,7 @@ func.func @maskedstore_negative_alignment(%base: memref<4xi32>, %mask: vector<32
13681368

13691369
// -----
13701370

1371-
func.func @maskedstore_nonpoweroftwo_alignment(%base: memref<4xi32>, %mask: vector<32xi1>, %value: vector<1xi32>, %index: index) {
1371+
func.func @maskedstore_non_power_of_2_alignment(%base: memref<4xi32>, %mask: vector<32xi1>, %value: vector<1xi32>, %index: index) {
13721372
// expected-error@below {{'vector.maskedstore' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
13731373
vector.maskedstore %base[%index], %mask, %value { alignment = 3 } : memref<4xi32>, vector<32xi1>, vector<1xi32> into vector<1xi32>
13741374
return
@@ -1470,6 +1470,24 @@ func.func @gather_pass_thru_type_mismatch(%base: memref<?xf32>, %indices: vector
14701470

14711471
// -----
14721472

1473+
func.func @gather_negative_alignment(%base: memref<16xf32>, %indices: vector<16xi32>,
1474+
%mask: vector<16xi1>, %pass_thru: vector<16xf32>, %c0 : index) {
1475+
// expected-error@+2 {{'vector.gather' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
1476+
%0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
1477+
{ alignment = -1 } : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
1478+
}
1479+
1480+
// -----
1481+
1482+
func.func @gather_non_power_of_two_alignment(%base: memref<16xf32>, %indices: vector<16xi32>,
1483+
%mask: vector<16xi1>, %pass_thru: vector<16xf32>, %c0 : index) {
1484+
// expected-error@+2 {{'vector.gather' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
1485+
%0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
1486+
{ alignment = 3 } : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
1487+
}
1488+
1489+
// -----
1490+
14731491
func.func @scatter_to_vector(%base: vector<16xf32>, %indices: vector<16xi32>,
14741492
%mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
14751493
%c0 = arith.constant 0 : index
@@ -1531,6 +1549,24 @@ func.func @scatter_dim_mask_mismatch(%base: memref<?xf32>, %indices: vector<16xi
15311549

15321550
// -----
15331551

1552+
func.func @scatter_negative_alignment(%base: memref<?xf32>, %indices: vector<16xi32>,
1553+
%mask: vector<16xi1>, %value: vector<16xf32>, %c0: index) {
1554+
// expected-error@+1 {{'vector.scatter' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
1555+
vector.scatter %base[%c0][%indices], %mask, %value { alignment = -1 }
1556+
: memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
1557+
}
1558+
1559+
// -----
1560+
1561+
func.func @scatter_non_power_of_2_alignment(%base: memref<?xf32>, %indices: vector<16xi32>,
1562+
%mask: vector<16xi1>, %value: vector<16xf32>, %c0: index) {
1563+
// expected-error@+1 {{'vector.scatter' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
1564+
vector.scatter %base[%c0][%indices], %mask, %value { alignment = 3 }
1565+
: memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
1566+
}
1567+
1568+
// -----
1569+
15341570
func.func @expand_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
15351571
%c0 = arith.constant 0 : index
15361572
// expected-error@+1 {{'vector.expandload' op base and result element type should match}}
@@ -1571,6 +1607,20 @@ func.func @expand_memref_mismatch(%base: memref<?x?xf32>, %mask: vector<16xi1>,
15711607

15721608
// -----
15731609

1610+
func.func @expand_negative_alignment(%base: memref<?xf32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>, %c0: index) {
1611+
// expected-error@+1 {{'vector.expandload' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
1612+
%0 = vector.expandload %base[%c0], %mask, %pass_thru { alignment = -1 } : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
1613+
}
1614+
1615+
// -----
1616+
1617+
func.func @expand_non_power_of_2_alignment(%base: memref<?xf32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>, %c0: index) {
1618+
// expected-error@+1 {{'vector.expandload' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
1619+
%0 = vector.expandload %base[%c0], %mask, %pass_thru { alignment = 3 } : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
1620+
}
1621+
1622+
// -----
1623+
15741624
func.func @compress_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %value: vector<16xf32>) {
15751625
%c0 = arith.constant 0 : index
15761626
// expected-error@+1 {{'vector.compressstore' op base and valueToStore element type should match}}
@@ -1603,6 +1653,20 @@ func.func @compress_memref_mismatch(%base: memref<?x?xf32>, %mask: vector<16xi1>
16031653

16041654
// -----
16051655

1656+
func.func @compress_negative_alignment(%base: memref<?xf32>, %mask: vector<16xi1>, %value: vector<16xf32>, %c0: index) {
1657+
// expected-error @below {{'vector.compressstore' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
1658+
vector.compressstore %base[%c0], %mask, %value { alignment = -1 } : memref<?xf32>, vector<16xi1>, vector<16xf32>
1659+
}
1660+
1661+
// -----
1662+
1663+
func.func @compress_non_power_of_2_alignment(%base: memref<?xf32>, %mask: vector<16xi1>, %value: vector<16xf32>, %c0: index) {
1664+
// expected-error @below {{'vector.compressstore' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
1665+
vector.compressstore %base[%c0], %mask, %value { alignment = 3 } : memref<?xf32>, vector<16xi1>, vector<16xf32>
1666+
}
1667+
1668+
// -----
1669+
16061670
func.func @scan_reduction_dim_constraint(%arg0: vector<2x3xi32>, %arg1: vector<3xi32>) -> vector<3xi32> {
16071671
// expected-error@+1 {{'vector.scan' op reduction dimension 5 has to be less than 2}}
16081672
%0:2 = vector.scan <add>, %arg0, %arg1 {inclusive = true, reduction_dim = 5} :
@@ -1952,15 +2016,15 @@ func.func @vector_load(%src : memref<?xi8>) {
19522016

19532017
// -----
19542018

1955-
func.func @invalid_load_alignment(%memref: memref<4xi32>, %c0: index) {
2019+
func.func @load_negative_alignment(%memref: memref<4xi32>, %c0: index) {
19562020
// expected-error @below {{'vector.load' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
19572021
%val = vector.load %memref[%c0] { alignment = -1 } : memref<4xi32>, vector<4xi32>
19582022
return
19592023
}
19602024

19612025
// -----
19622026

1963-
func.func @invalid_load_alignment(%memref: memref<4xi32>, %c0: index) {
2027+
func.func @load_non_pow_of_2_alignment(%memref: memref<4xi32>, %c0: index) {
19642028
// expected-error @below {{'vector.load' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
19652029
%val = vector.load %memref[%c0] { alignment = 3 } : memref<4xi32>, vector<4xi32>
19662030
return
@@ -1981,15 +2045,15 @@ func.func @vector_store(%dest : memref<?xi8>, %vec : vector<16x16xi8>) {
19812045

19822046
// -----
19832047

1984-
func.func @invalid_store_alignment(%memref: memref<4xi32>, %val: vector<4xi32>, %c0: index) {
2048+
func.func @store_negative_alignment(%memref: memref<4xi32>, %val: vector<4xi32>, %c0: index) {
19852049
// expected-error @below {{'vector.store' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
19862050
vector.store %val, %memref[%c0] { alignment = -1 } : memref<4xi32>, vector<4xi32>
19872051
return
19882052
}
19892053

19902054
// -----
19912055

1992-
func.func @invalid_store_alignment(%memref: memref<4xi32>, %val: vector<4xi32>, %c0: index) {
2056+
func.func @store_non_pow_of_2_alignment(%memref: memref<4xi32>, %val: vector<4xi32>, %c0: index) {
19932057
// expected-error @below {{'vector.store' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
19942058
vector.store %val, %memref[%c0] { alignment = 3 } : memref<4xi32>, vector<4xi32>
19952059
return

0 commit comments

Comments
 (0)