Skip to content

Commit a97cbc6

Browse files
committed
[mlir][vector] Add alignment to vector.gather.
1 parent ef30dd3 commit a97cbc6

File tree

2 files changed

+46
-1
lines changed

2 files changed

+46
-1
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2054,7 +2054,9 @@ def Vector_GatherOp :
20542054
Variadic<Index>:$indices,
20552055
VectorOfNonZeroRankOf<[AnyInteger, Index]>:$index_vec,
20562056
VectorOfNonZeroRankOf<[I1]>:$mask,
2057-
AnyVectorOfNonZeroRank:$pass_thru)>,
2057+
AnyVectorOfNonZeroRank:$pass_thru,
2058+
ConfinedAttr<OptionalAttr<I64Attr>,
2059+
[AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment)>,
20582060
Results<(outs AnyVectorOfNonZeroRank:$result)> {
20592061

20602062
let summary = [{
@@ -2111,6 +2113,31 @@ def Vector_GatherOp :
21112113
"`into` type($result)";
21122114
let hasCanonicalizer = 1;
21132115
let hasVerifier = 1;
2116+
2117+
let builders = [
2118+
OpBuilder<(ins "VectorType":$resultType,
2119+
"Value":$base,
2120+
"ValueRange":$indices,
2121+
"Value":$index_vec,
2122+
"Value":$mask,
2123+
"Value":$passthrough,
2124+
CArg<"llvm::Align", "llvm::Align()">:$alignment), [{
2125+
return build($_builder, $_state, resultType, base, indices, index_vec, mask, passthrough,
2126+
alignment != llvm::Align() ? $_builder.getI64IntegerAttr(alignment.value()) :
2127+
nullptr);
2128+
}]>,
2129+
OpBuilder<(ins "TypeRange":$resultTypes,
2130+
"Value":$base,
2131+
"ValueRange":$indices,
2132+
"Value":$index_vec,
2133+
"Value":$mask,
2134+
"Value":$passthrough,
2135+
CArg<"llvm::Align", "llvm::Align()">:$alignment), [{
2136+
return build($_builder, $_state, resultTypes, base, indices, index_vec, mask, passthrough,
2137+
alignment != llvm::Align() ? $_builder.getI64IntegerAttr(alignment.value()) :
2138+
nullptr);
2139+
}]>
2140+
];
21142141
}
21152142

21162143
def Vector_ScatterOp :

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1470,6 +1470,24 @@ func.func @gather_pass_thru_type_mismatch(%base: memref<?xf32>, %indices: vector
14701470

14711471
// -----
14721472

1473+
func.func @gather_invalid_alignment(%base: memref<16xf32>, %indices: vector<16xi32>,
1474+
%mask: vector<16xi1>, %pass_thru: vector<16xf32>, %c0 : index) {
1475+
// expected-error@+2 {{'vector.gather' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
1476+
%0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
1477+
{ alignment = -1 } : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
1478+
}
1479+
1480+
// -----
1481+
1482+
func.func @gather_invalid_alignment(%base: memref<16xf32>, %indices: vector<16xi32>,
1483+
%mask: vector<16xi1>, %pass_thru: vector<16xf32>, %c0 : index) {
1484+
// expected-error@+2 {{'vector.gather' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
1485+
%0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
1486+
{ alignment = 3 } : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
1487+
}
1488+
1489+
// -----
1490+
14731491
func.func @scatter_to_vector(%base: vector<16xf32>, %indices: vector<16xi32>,
14741492
%mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
14751493
%c0 = arith.constant 0 : index

0 commit comments

Comments
 (0)