Skip to content

Commit a4d820f

Browse files
committed
[mlir][vector] Add alignment to expandload
1 parent 36949d1 commit a4d820f

File tree

2 files changed

+40
-1
lines changed

2 files changed

+40
-1
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2226,7 +2226,9 @@ def Vector_ExpandLoadOp :
22262226
Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
22272227
Variadic<Index>:$indices,
22282228
FixedVectorOfNonZeroRankOf<[I1]>:$mask,
2229-
AnyVectorOfNonZeroRank:$pass_thru)>,
2229+
AnyVectorOfNonZeroRank:$pass_thru,
2230+
ConfinedAttr<OptionalAttr<I64Attr>,
2231+
[AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment)>,
22302232
Results<(outs AnyVectorOfNonZeroRank:$result)> {
22312233

22322234
let summary = "reads elements from memory and spreads them into a vector as defined by a mask";
@@ -2288,6 +2290,29 @@ def Vector_ExpandLoadOp :
22882290
"type($base) `,` type($mask) `,` type($pass_thru) `into` type($result)";
22892291
let hasCanonicalizer = 1;
22902292
let hasVerifier = 1;
2293+
2294+
let builders = [
2295+
OpBuilder<(ins "VectorType":$resultType,
2296+
"Value":$base,
2297+
"ValueRange":$indices,
2298+
"Value":$mask,
2299+
"Value":$passthrough,
2300+
CArg<"llvm::Align", "llvm::Align()">:$alignment), [{
2301+
return build($_builder, $_state, resultType, base, indices, mask, passthrough,
2302+
alignment != llvm::Align() ? $_builder.getI64IntegerAttr(alignment.value()) :
2303+
nullptr);
2304+
}]>,
2305+
OpBuilder<(ins "TypeRange":$resultTypes,
2306+
"Value":$base,
2307+
"ValueRange":$indices,
2308+
"Value":$mask,
2309+
"Value":$passthrough,
2310+
CArg<"llvm::Align", "llvm::Align()">:$alignment), [{
2311+
return build($_builder, $_state, resultTypes, base, indices, mask, passthrough,
2312+
alignment != llvm::Align() ? $_builder.getI64IntegerAttr(alignment.value()) :
2313+
nullptr);
2314+
}]>
2315+
];
22912316
}
22922317

22932318
def Vector_CompressStoreOp :

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1607,6 +1607,20 @@ func.func @expand_memref_mismatch(%base: memref<?x?xf32>, %mask: vector<16xi1>,
16071607

16081608
// -----
16091609

1610+
func.func @expand_invalid_alignment(%base: memref<?xf32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>, %c0: index) {
1611+
// expected-error@+1 {{'vector.expandload' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
1612+
%0 = vector.expandload %base[%c0], %mask, %pass_thru { alignment = -1 } : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
1613+
}
1614+
1615+
// -----
1616+
1617+
func.func @expand_invalid_alignment(%base: memref<?xf32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>, %c0: index) {
1618+
// expected-error@+1 {{'vector.expandload' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
1619+
%0 = vector.expandload %base[%c0], %mask, %pass_thru { alignment = 3 } : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
1620+
}
1621+
1622+
// -----
1623+
16101624
func.func @compress_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %value: vector<16xf32>) {
16111625
%c0 = arith.constant 0 : index
16121626
// expected-error@+1 {{'vector.compressstore' op base and valueToStore element type should match}}

0 commit comments

Comments
 (0)