Skip to content

Commit da40554

Browse files
committed
[mlir][vector] Add alignment to vector.gather.
1 parent 394c64f commit da40554

File tree

2 files changed

+46
-1
lines changed

2 files changed

+46
-1
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2003,7 +2003,9 @@ def Vector_GatherOp :
20032003
Variadic<Index>:$indices,
20042004
VectorOfNonZeroRankOf<[AnyInteger, Index]>:$index_vec,
20052005
VectorOfNonZeroRankOf<[I1]>:$mask,
2006-
AnyVectorOfNonZeroRank:$pass_thru)>,
2006+
AnyVectorOfNonZeroRank:$pass_thru,
2007+
ConfinedAttr<OptionalAttr<I64Attr>,
2008+
[AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment)>,
20072009
Results<(outs AnyVectorOfNonZeroRank:$result)> {
20082010

20092011
let summary = [{
@@ -2060,6 +2062,31 @@ def Vector_GatherOp :
20602062
"`into` type($result)";
20612063
let hasCanonicalizer = 1;
20622064
let hasVerifier = 1;
2065+
2066+
let builders = [
2067+
OpBuilder<(ins "VectorType":$resultType,
2068+
"Value":$base,
2069+
"ValueRange":$indices,
2070+
"Value":$index_vec,
2071+
"Value":$mask,
2072+
"Value":$passthrough,
2073+
CArg<"llvm::Align", "llvm::Align()">:$alignment), [{
2074+
return build($_builder, $_state, resultType, base, indices, index_vec, mask, passthrough,
2075+
alignment != llvm::Align() ? $_builder.getI64IntegerAttr(alignment.value()) :
2076+
nullptr);
2077+
}]>,
2078+
OpBuilder<(ins "TypeRange":$resultTypes,
2079+
"Value":$base,
2080+
"ValueRange":$indices,
2081+
"Value":$index_vec,
2082+
"Value":$mask,
2083+
"Value":$passthrough,
2084+
CArg<"llvm::Align", "llvm::Align()">:$alignment), [{
2085+
return build($_builder, $_state, resultTypes, base, indices, index_vec, mask, passthrough,
2086+
alignment != llvm::Align() ? $_builder.getI64IntegerAttr(alignment.value()) :
2087+
nullptr);
2088+
}]>
2089+
];
20632090
}
20642091

20652092
def Vector_ScatterOp :

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1430,6 +1430,24 @@ func.func @gather_pass_thru_type_mismatch(%base: memref<?xf32>, %indices: vector
14301430

14311431
// -----
14321432

1433+
func.func @gather_invalid_alignment(%base: memref<16xf32>, %indices: vector<16xi32>,
1434+
%mask: vector<16xi1>, %pass_thru: vector<16xf32>, %c0 : index) {
1435+
// expected-error@+2 {{'vector.gather' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
1436+
%0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
1437+
{ alignment = -1 } : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
1438+
}
1439+
1440+
// -----
1441+
1442+
func.func @gather_invalid_alignment(%base: memref<16xf32>, %indices: vector<16xi32>,
1443+
%mask: vector<16xi1>, %pass_thru: vector<16xf32>, %c0 : index) {
1444+
// expected-error@+2 {{'vector.gather' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
1445+
%0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
1446+
{ alignment = 3 } : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
1447+
}
1448+
1449+
// -----
1450+
14331451
func.func @scatter_to_vector(%base: vector<16xf32>, %indices: vector<16xi32>,
14341452
%mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
14351453
%c0 = arith.constant 0 : index

0 commit comments

Comments
 (0)