Skip to content

Commit 12eead5

Browse files
committed
[mlir][vector] Add alignment to vector.scatter
1 parent da40554 commit 12eead5

File tree

2 files changed

+34
-1
lines changed

2 files changed

+34
-1
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2095,7 +2095,9 @@ def Vector_ScatterOp :
20952095
Variadic<Index>:$indices,
20962096
VectorOfNonZeroRankOf<[AnyInteger, Index]>:$index_vec,
20972097
VectorOfNonZeroRankOf<[I1]>:$mask,
2098-
AnyVectorOfNonZeroRank:$valueToStore)> {
2098+
AnyVectorOfNonZeroRank:$valueToStore,
2099+
ConfinedAttr<OptionalAttr<I64Attr>,
2100+
[AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment)> {
20992101

21002102
let summary = [{
21012103
scatters elements from a vector into memory as defined by an index vector
@@ -2153,6 +2155,19 @@ def Vector_ScatterOp :
21532155
"type($index_vec) `,` type($mask) `,` type($valueToStore)";
21542156
let hasCanonicalizer = 1;
21552157
let hasVerifier = 1;
2158+
2159+
let builders = [
2160+
OpBuilder<(ins "Value":$base,
2161+
"ValueRange":$indices,
2162+
"Value":$index_vec,
2163+
"Value":$mask,
2164+
"Value":$valueToStore,
2165+
CArg<"llvm::Align", "llvm::Align()">: $alignment), [{
2166+
return build($_builder, $_state, base, indices, index_vec, mask, valueToStore,
2167+
alignment != llvm::Align() ? $_builder.getI64IntegerAttr(alignment.value()) :
2168+
nullptr);
2169+
}]>
2170+
];
21562171
}
21572172

21582173
def Vector_ExpandLoadOp :

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1509,6 +1509,24 @@ func.func @scatter_dim_mask_mismatch(%base: memref<?xf32>, %indices: vector<16xi
15091509

15101510
// -----
15111511

1512+
func.func @scatter_invalid_alignment(%base: memref<?xf32>, %indices: vector<16xi32>,
1513+
%mask: vector<16xi1>, %value: vector<16xf32>, %c0: index) {
1514+
// expected-error@+1 {{'vector.scatter' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
1515+
vector.scatter %base[%c0][%indices], %mask, %value { alignment = -1 }
1516+
: memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
1517+
}
1518+
1519+
// -----
1520+
1521+
func.func @scatter_invalid_alignment(%base: memref<?xf32>, %indices: vector<16xi32>,
1522+
%mask: vector<16xi1>, %value: vector<16xf32>, %c0: index) {
1523+
// expected-error@+1 {{'vector.scatter' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
1524+
vector.scatter %base[%c0][%indices], %mask, %value { alignment = 3 }
1525+
: memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
1526+
}
1527+
1528+
// -----
1529+
15121530
func.func @expand_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
15131531
%c0 = arith.constant 0 : index
15141532
// expected-error@+1 {{'vector.expandload' op base and result element type should match}}

0 commit comments

Comments
 (0)