Skip to content

Commit a1672d7

Browse files
authored
[mlir][vector] Add alignment attribute to maskedload and maskedstore (#151690)
These commits continue the work done in #144344, of adding alignment attributes to operations in the vector and memref. These commits focus on adding the alignment attribute to the `maskedload` and `maskedstore` operations. The `VectorLoadConversion` pattern in VectorToLLVM is a template for `load`, `store`, `maskedload` and `maskedstore` operations. Having the alignment attribute in all these operations would allow for an easy way to propagate the alignment attribute from the vector dialect to the LLVM dialect. This patchset also includes changes to the conversion from VectorToLLVM to propagate the alignment attribute for the vector.{,masked}{load,store} operations.
1 parent fd41700 commit a1672d7

File tree

4 files changed

+160
-14
lines changed

4 files changed

+160
-14
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 59 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1729,18 +1729,18 @@ def Vector_LoadOp : Vector_Op<"load", [
17291729
"Value":$base,
17301730
"ValueRange":$indices,
17311731
CArg<"bool", "false">:$nontemporal,
1732-
CArg<"llvm::Align", "llvm::Align()">:$alignment), [{
1732+
CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">:$alignment), [{
17331733
return build($_builder, $_state, resultType, base, indices, nontemporal,
1734-
alignment != llvm::Align() ? $_builder.getI64IntegerAttr(alignment.value()) :
1734+
alignment.has_value() ? $_builder.getI64IntegerAttr(alignment->value()) :
17351735
nullptr);
17361736
}]>,
17371737
OpBuilder<(ins "TypeRange":$resultTypes,
17381738
"Value":$base,
17391739
"ValueRange":$indices,
17401740
CArg<"bool", "false">:$nontemporal,
1741-
CArg<"llvm::Align", "llvm::Align()">:$alignment), [{
1741+
CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">:$alignment), [{
17421742
return build($_builder, $_state, resultTypes, base, indices, nontemporal,
1743-
alignment != llvm::Align() ? $_builder.getI64IntegerAttr(alignment.value()) :
1743+
alignment.has_value() ? $_builder.getI64IntegerAttr(alignment->value()) :
17441744
nullptr);
17451745
}]>
17461746
];
@@ -1847,9 +1847,9 @@ def Vector_StoreOp : Vector_Op<"store", [
18471847
"Value":$base,
18481848
"ValueRange":$indices,
18491849
CArg<"bool", "false">:$nontemporal,
1850-
CArg<"llvm::Align", "llvm::Align()">:$alignment), [{
1850+
CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">:$alignment), [{
18511851
return build($_builder, $_state, valueToStore, base, indices, nontemporal,
1852-
alignment != llvm::Align() ? $_builder.getI64IntegerAttr(alignment.value()) :
1852+
alignment.has_value() ? $_builder.getI64IntegerAttr(alignment->value()) :
18531853
nullptr);
18541854
}]>
18551855
];
@@ -1876,7 +1876,9 @@ def Vector_MaskedLoadOp :
18761876
Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
18771877
Variadic<Index>:$indices,
18781878
VectorOfNonZeroRankOf<[I1]>:$mask,
1879-
AnyVectorOfNonZeroRank:$pass_thru)>,
1879+
AnyVectorOfNonZeroRank:$pass_thru,
1880+
ConfinedAttr<OptionalAttr<I64Attr>,
1881+
[AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment)>,
18801882
Results<(outs AnyVectorOfNonZeroRank:$result)> {
18811883

18821884
let summary = "loads elements from memory into a vector as defined by a mask vector";
@@ -1912,6 +1914,12 @@ def Vector_MaskedLoadOp :
19121914
%1 = vector.maskedload %base[%i, %j], %mask, %pass_thru
19131915
: memref<?x?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
19141916
```
1917+
1918+
An optional `alignment` attribute allows to specify the byte alignment of the
1919+
load operation. It must be a positive power of 2. The operation must access
1920+
memory at an address aligned to this boundary. Violations may lead to
1921+
architecture-specific faults or performance penalties.
1922+
A value of 0 indicates no specific alignment requirement.
19151923
}];
19161924
let extraClassDeclaration = [{
19171925
MemRefType getMemRefType() {
@@ -1932,14 +1940,39 @@ def Vector_MaskedLoadOp :
19321940
let hasCanonicalizer = 1;
19331941
let hasFolder = 1;
19341942
let hasVerifier = 1;
1943+
1944+
let builders = [
1945+
OpBuilder<(ins "VectorType":$resultType,
1946+
"Value":$base,
1947+
"ValueRange":$indices,
1948+
"Value":$mask,
1949+
"Value":$passthrough,
1950+
CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">:$alignment), [{
1951+
return build($_builder, $_state, resultType, base, indices, mask, passthrough,
1952+
alignment.has_value() ? $_builder.getI64IntegerAttr(alignment->value()) :
1953+
nullptr);
1954+
}]>,
1955+
OpBuilder<(ins "TypeRange":$resultTypes,
1956+
"Value":$base,
1957+
"ValueRange":$indices,
1958+
"Value":$mask,
1959+
"Value":$passthrough,
1960+
CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">:$alignment), [{
1961+
return build($_builder, $_state, resultTypes, base, indices, mask, passthrough,
1962+
alignment.has_value() ? $_builder.getI64IntegerAttr(alignment->value()) :
1963+
nullptr);
1964+
}]>
1965+
];
19351966
}
19361967

19371968
def Vector_MaskedStoreOp :
19381969
Vector_Op<"maskedstore">,
19391970
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
19401971
Variadic<Index>:$indices,
19411972
VectorOfNonZeroRankOf<[I1]>:$mask,
1942-
AnyVectorOfNonZeroRank:$valueToStore)> {
1973+
AnyVectorOfNonZeroRank:$valueToStore,
1974+
ConfinedAttr<OptionalAttr<I64Attr>,
1975+
[AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment)> {
19431976

19441977
let summary = "stores elements from a vector into memory as defined by a mask vector";
19451978

@@ -1974,6 +2007,12 @@ def Vector_MaskedStoreOp :
19742007
vector.maskedstore %base[%i, %j], %mask, %value
19752008
: memref<?x?xf32>, vector<16xi1>, vector<16xf32>
19762009
```
2010+
2011+
An optional `alignment` attribute allows to specify the byte alignment of the
2012+
store operation. It must be a positive power of 2. The operation must access
2013+
memory at an address aligned to this boundary. Violations may lead to
2014+
architecture-specific faults or performance penalties.
2015+
A value of 0 indicates no specific alignment requirement.
19772016
}];
19782017
let extraClassDeclaration = [{
19792018
MemRefType getMemRefType() {
@@ -1992,6 +2031,18 @@ def Vector_MaskedStoreOp :
19922031
let hasCanonicalizer = 1;
19932032
let hasFolder = 1;
19942033
let hasVerifier = 1;
2034+
2035+
let builders = [
2036+
OpBuilder<(ins "Value":$base,
2037+
"ValueRange":$indices,
2038+
"Value":$mask,
2039+
"Value":$valueToStore,
2040+
CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">:$alignment), [{
2041+
return build($_builder, $_state, base, indices, mask, valueToStore,
2042+
alignment.has_value() ? $_builder.getI64IntegerAttr(alignment->value()) :
2043+
nullptr);
2044+
}]>
2045+
];
19952046
}
19962047

19972048
def Vector_GatherOp :

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -247,8 +247,9 @@ class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
247247
MemRefType memRefTy = loadOrStoreOp.getMemRefType();
248248

249249
// Resolve alignment.
250-
unsigned align;
251-
if (failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vectorTy,
250+
unsigned align = loadOrStoreOp.getAlignment().value_or(0);
251+
if (!align &&
252+
failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vectorTy,
252253
memRefTy, align, useVectorAlignment)))
253254
return rewriter.notifyMatchFailure(loadOrStoreOp,
254255
"could not resolve alignment");

mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1679,6 +1679,16 @@ func.func @load_0d(%memref : memref<200x100xf32>, %i : index, %j : index) -> vec
16791679

16801680
// -----
16811681

1682+
func.func @load_with_alignment(%memref : memref<200x100xf32>, %i : index, %j : index) -> vector<8xf32> {
1683+
%0 = vector.load %memref[%i, %j] { alignment = 8 } : memref<200x100xf32>, vector<8xf32>
1684+
return %0 : vector<8xf32>
1685+
}
1686+
1687+
// CHECK-LABEL: func @load_with_alignment
1688+
// CHECK: llvm.load {{.*}} {alignment = 8 : i64} : !llvm.ptr -> vector<8xf32>
1689+
1690+
// -----
1691+
16821692
//===----------------------------------------------------------------------===//
16831693
// vector.store
16841694
//===----------------------------------------------------------------------===//
@@ -1785,6 +1795,16 @@ func.func @store_0d(%memref : memref<200x100xf32>, %i : index, %j : index) {
17851795

17861796
// -----
17871797

1798+
func.func @store_with_alignment(%memref : memref<200x100xf32>, %i : index, %j : index, %val : vector<4xf32>) {
1799+
vector.store %val, %memref[%i, %j] {alignment = 8} : memref<200x100xf32>, vector<4xf32>
1800+
return
1801+
}
1802+
1803+
// CHECK-LABEL: func @store_with_alignment
1804+
// CHECK: llvm.store %{{.*}} {alignment = 8 : i64} : vector<4xf32>, !llvm.ptr
1805+
1806+
// -----
1807+
17881808
//===----------------------------------------------------------------------===//
17891809
// vector.maskedload
17901810
//===----------------------------------------------------------------------===//
@@ -1839,6 +1859,16 @@ func.func @masked_load_index_scalable(%arg0: memref<?xindex>, %arg1: vector<[16]
18391859

18401860
// -----
18411861

1862+
func.func @masked_load_with_alignment(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<16xf32>, %arg3: index) -> vector<16xf32> {
1863+
%0 = vector.maskedload %arg0[%arg3], %arg1, %arg2 { alignment = 2 } : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
1864+
return %0 : vector<16xf32>
1865+
}
1866+
1867+
// CHECK-LABEL: func @masked_load_with_alignment
1868+
// CHECK: llvm.intr.masked.load %{{.*}} {alignment = 2 : i32} : (!llvm.ptr, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
1869+
1870+
// -----
1871+
18421872
//===----------------------------------------------------------------------===//
18431873
// vector.maskedstore
18441874
//===----------------------------------------------------------------------===//
@@ -1891,6 +1921,16 @@ func.func @masked_store_index_scalable(%arg0: memref<?xindex>, %arg1: vector<[16
18911921

18921922
// -----
18931923

1924+
func.func @masked_store_with_alignment(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<16xf32>, %arg3: index) {
1925+
vector.maskedstore %arg0[%arg3], %arg1, %arg2 { alignment = 2 } : memref<?xf32>, vector<16xi1>, vector<16xf32>
1926+
return
1927+
}
1928+
1929+
// CHECK-LABEL: func @masked_store_with_alignment
1930+
// CHECK: llvm.intr.masked.store %{{.*}} {alignment = 2 : i32} : vector<16xf32>, vector<16xi1> into !llvm.ptr
1931+
1932+
// -----
1933+
18941934
//===----------------------------------------------------------------------===//
18951935
// vector.gather
18961936
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 58 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1305,6 +1305,26 @@ func.func @store_memref_index_mismatch(%base : memref<?xf32>, %value : vector<16
13051305

13061306
// -----
13071307

1308+
//===----------------------------------------------------------------------===//
1309+
// vector.maskedload
1310+
//===----------------------------------------------------------------------===//
1311+
1312+
func.func @maskedload_negative_alignment(%base: memref<4xi32>, %mask: vector<32xi1>, %pass: vector<1xi32>, %index: index) {
1313+
// expected-error@below {{'vector.maskedload' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
1314+
%val = vector.maskedload %base[%index], %mask, %pass { alignment = -1 } : memref<4xi32>, vector<32xi1>, vector<1xi32> into vector<1xi32>
1315+
return
1316+
}
1317+
1318+
// -----
1319+
1320+
func.func @maskedload_nonpoweroftwo_alignment(%base: memref<4xi32>, %mask: vector<32xi1>, %pass: vector<1xi32>, %index: index) {
1321+
// expected-error@below {{'vector.maskedload' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
1322+
%val = vector.maskedload %base[%index], %mask, %pass { alignment = 3 } : memref<4xi32>, vector<32xi1>, vector<1xi32> into vector<1xi32>
1323+
return
1324+
}
1325+
1326+
// -----
1327+
13081328
func.func @maskedload_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %pass: vector<16xf32>) {
13091329
%c0 = arith.constant 0 : index
13101330
// expected-error@+1 {{'vector.maskedload' op base and result element type should match}}
@@ -1336,6 +1356,26 @@ func.func @maskedload_memref_mismatch(%base: memref<?xf32>, %mask: vector<16xi1>
13361356

13371357
// -----
13381358

1359+
//===----------------------------------------------------------------------===//
1360+
// vector.maskedstore
1361+
//===----------------------------------------------------------------------===//
1362+
1363+
func.func @maskedstore_negative_alignment(%base: memref<4xi32>, %mask: vector<32xi1>, %value: vector<1xi32>, %index: index) {
1364+
// expected-error@below {{'vector.maskedstore' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
1365+
vector.maskedstore %base[%index], %mask, %value { alignment = -1 } : memref<4xi32>, vector<32xi1>, vector<1xi32> into vector<1xi32>
1366+
return
1367+
}
1368+
1369+
// -----
1370+
1371+
func.func @maskedstore_nonpoweroftwo_alignment(%base: memref<4xi32>, %mask: vector<32xi1>, %value: vector<1xi32>, %index: index) {
1372+
// expected-error@below {{'vector.maskedstore' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
1373+
vector.maskedstore %base[%index], %mask, %value { alignment = 3 } : memref<4xi32>, vector<32xi1>, vector<1xi32> into vector<1xi32>
1374+
return
1375+
}
1376+
1377+
// -----
1378+
13391379
func.func @maskedstore_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %value: vector<16xf32>) {
13401380
%c0 = arith.constant 0 : index
13411381
// expected-error@+1 {{'vector.maskedstore' op base and valueToStore element type should match}}
@@ -1912,15 +1952,22 @@ func.func @vector_load(%src : memref<?xi8>) {
19121952

19131953
// -----
19141954

1915-
func.func @invalid_load_alignment(%memref: memref<4xi32>) {
1916-
%c0 = arith.constant 0 : index
1955+
func.func @invalid_load_alignment(%memref: memref<4xi32>, %c0: index) {
19171956
// expected-error @below {{'vector.load' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
19181957
%val = vector.load %memref[%c0] { alignment = -1 } : memref<4xi32>, vector<4xi32>
19191958
return
19201959
}
19211960

19221961
// -----
19231962

1963+
func.func @invalid_load_alignment(%memref: memref<4xi32>, %c0: index) {
1964+
// expected-error @below {{'vector.load' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
1965+
%val = vector.load %memref[%c0] { alignment = 3 } : memref<4xi32>, vector<4xi32>
1966+
return
1967+
}
1968+
1969+
// -----
1970+
19241971
//===----------------------------------------------------------------------===//
19251972
// vector.store
19261973
//===----------------------------------------------------------------------===//
@@ -1934,8 +1981,15 @@ func.func @vector_store(%dest : memref<?xi8>, %vec : vector<16x16xi8>) {
19341981

19351982
// -----
19361983

1937-
func.func @invalid_store_alignment(%memref: memref<4xi32>, %val: vector<4xi32>) {
1938-
%c0 = arith.constant 0 : index
1984+
func.func @invalid_store_alignment(%memref: memref<4xi32>, %val: vector<4xi32>, %c0: index) {
1985+
// expected-error @below {{'vector.store' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
1986+
vector.store %val, %memref[%c0] { alignment = -1 } : memref<4xi32>, vector<4xi32>
1987+
return
1988+
}
1989+
1990+
// -----
1991+
1992+
func.func @invalid_store_alignment(%memref: memref<4xi32>, %val: vector<4xi32>, %c0: index) {
19391993
// expected-error @below {{'vector.store' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
19401994
vector.store %val, %memref[%c0] { alignment = 3 } : memref<4xi32>, vector<4xi32>
19411995
return

0 commit comments

Comments
 (0)