Skip to content

Commit 158b91c

Browse files
committed
[mlir][vector] Add alignment to vector.scatter
1 parent a97cbc6 commit 158b91c

File tree

2 files changed

+34
-1
lines changed

2 files changed

+34
-1
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2146,7 +2146,9 @@ def Vector_ScatterOp :
21462146
Variadic<Index>:$indices,
21472147
VectorOfNonZeroRankOf<[AnyInteger, Index]>:$index_vec,
21482148
VectorOfNonZeroRankOf<[I1]>:$mask,
2149-
AnyVectorOfNonZeroRank:$valueToStore)> {
2149+
AnyVectorOfNonZeroRank:$valueToStore,
2150+
ConfinedAttr<OptionalAttr<I64Attr>,
2151+
[AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment)> {
21502152

21512153
let summary = [{
21522154
scatters elements from a vector into memory as defined by an index vector
@@ -2204,6 +2206,19 @@ def Vector_ScatterOp :
22042206
"type($index_vec) `,` type($mask) `,` type($valueToStore)";
22052207
let hasCanonicalizer = 1;
22062208
let hasVerifier = 1;
2209+
2210+
let builders = [
2211+
OpBuilder<(ins "Value":$base,
2212+
"ValueRange":$indices,
2213+
"Value":$index_vec,
2214+
"Value":$mask,
2215+
"Value":$valueToStore,
2216+
CArg<"llvm::Align", "llvm::Align()">: $alignment), [{
2217+
return build($_builder, $_state, base, indices, index_vec, mask, valueToStore,
2218+
alignment != llvm::Align() ? $_builder.getI64IntegerAttr(alignment.value()) :
2219+
nullptr);
2220+
}]>
2221+
];
22072222
}
22082223

22092224
def Vector_ExpandLoadOp :

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1549,6 +1549,24 @@ func.func @scatter_dim_mask_mismatch(%base: memref<?xf32>, %indices: vector<16xi
15491549

15501550
// -----
15511551

1552+
func.func @scatter_invalid_alignment(%base: memref<?xf32>, %indices: vector<16xi32>,
1553+
%mask: vector<16xi1>, %value: vector<16xf32>, %c0: index) {
1554+
// expected-error@+1 {{'vector.scatter' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
1555+
vector.scatter %base[%c0][%indices], %mask, %value { alignment = -1 }
1556+
: memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
1557+
}
1558+
1559+
// -----
1560+
1561+
func.func @scatter_invalid_alignment(%base: memref<?xf32>, %indices: vector<16xi32>,
1562+
%mask: vector<16xi1>, %value: vector<16xf32>, %c0: index) {
1563+
// expected-error@+1 {{'vector.scatter' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
1564+
vector.scatter %base[%c0][%indices], %mask, %value { alignment = 3 }
1565+
: memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
1566+
}
1567+
1568+
// -----
1569+
15521570
func.func @expand_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
15531571
%c0 = arith.constant 0 : index
15541572
// expected-error@+1 {{'vector.expandload' op base and result element type should match}}

0 commit comments

Comments
 (0)