Skip to content

Commit 96ba831

Browse files
committed
[mlir][vector] Add alignment to maskedload.
1 parent 3b5aff5 commit 96ba831

File tree

2 files changed

+54
-1
lines changed

2 files changed

+54
-1
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1876,7 +1876,9 @@ def Vector_MaskedLoadOp :
18761876
Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
18771877
Variadic<Index>:$indices,
18781878
VectorOfNonZeroRankOf<[I1]>:$mask,
1879-
AnyVectorOfNonZeroRank:$pass_thru)>,
1879+
AnyVectorOfNonZeroRank:$pass_thru,
1880+
ConfinedAttr<OptionalAttr<I64Attr>,
1881+
[AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment)>,
18801882
Results<(outs AnyVectorOfNonZeroRank:$result)> {
18811883

18821884
let summary = "loads elements from memory into a vector as defined by a mask vector";
@@ -1912,6 +1914,12 @@ def Vector_MaskedLoadOp :
19121914
%1 = vector.maskedload %base[%i, %j], %mask, %pass_thru
19131915
: memref<?x?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
19141916
```
1917+
1918+
An optional `alignment` attribute allows to specify the byte alignment of the
1919+
load operation. It must be a positive power of 2. The operation must access
1920+
memory at an address aligned to this boundary. Violations may lead to
1921+
architecture-specific faults or performance penalties.
1922+
A value of 0 indicates no specific alignment requirement.
19151923
}];
19161924
let extraClassDeclaration = [{
19171925
MemRefType getMemRefType() {
@@ -1932,6 +1940,29 @@ def Vector_MaskedLoadOp :
19321940
let hasCanonicalizer = 1;
19331941
let hasFolder = 1;
19341942
let hasVerifier = 1;
1943+
1944+
let builders = [
1945+
OpBuilder<(ins "VectorType":$resultType,
1946+
"Value":$base,
1947+
"ValueRange":$indices,
1948+
"Value":$mask,
1949+
"Value":$passthrough,
1950+
CArg<"uint64_t", "0">:$alignment), [{
1951+
return build($_builder, $_state, resultType, base, indices, mask, passthrough,
1952+
alignment != 0 ? $_builder.getI64IntegerAttr(alignment) :
1953+
nullptr);
1954+
}]>,
1955+
OpBuilder<(ins "TypeRange":$resultTypes,
1956+
"Value":$base,
1957+
"ValueRange":$indices,
1958+
"Value":$mask,
1959+
"Value":$passthrough,
1960+
CArg<"uint64_t", "0">:$alignment), [{
1961+
return build($_builder, $_state, resultTypes, base, indices, mask, passthrough,
1962+
alignment != 0 ? $_builder.getI64IntegerAttr(alignment) :
1963+
nullptr);
1964+
}]>
1965+
];
19351966
}
19361967

19371968
def Vector_MaskedStoreOp :

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1305,6 +1305,28 @@ func.func @store_memref_index_mismatch(%base : memref<?xf32>, %value : vector<16
13051305

13061306
// -----
13071307

1308+
//===----------------------------------------------------------------------===//
1309+
// vector.maskedload
1310+
//===----------------------------------------------------------------------===//
1311+
1312+
func.func @maskedload_negative_alignment(%base: memref<?xf32>, %mask: vector<16xi1>, %pass: vector<16xf32>) {
1313+
%c0 = arith.constant 0 : index
1314+
// expected-error@+1 {{'vector.maskedload' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
1315+
%val = vector.maskedload %base[%c0], %mask, %pass { alignment = -1 } : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
1316+
return
1317+
}
1318+
1319+
// -----
1320+
1321+
func.func @maskedload_nonpower2_alignment(%base: memref<?xf32>, %mask: vector<16xi1>, %pass: vector<16xf32>) {
1322+
%c0 = arith.constant 0 : index
1323+
// expected-error@+1 {{'vector.maskedload' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
1324+
%val = vector.maskedload %base[%c0], %mask, %pass { alignment = 3 } : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
1325+
return
1326+
}
1327+
1328+
// -----
1329+
13081330
func.func @maskedload_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %pass: vector<16xf32>) {
13091331
%c0 = arith.constant 0 : index
13101332
// expected-error@+1 {{'vector.maskedload' op base and result element type should match}}

0 commit comments

Comments
 (0)