Skip to content

Commit 32b9f8c

Browse files
author
Yang Bai
committed
[mlir][vector] Support complete folding in single pass for vector.insert/vector.extract
After successfully converting dynamic indices to static indices, continue folding instead of returning early, allowing subsequent fold operations to be executed.
1 parent 7c99601 commit 32b9f8c

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2062,6 +2062,7 @@ static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
20622062
if (opChange) {
20632063
op.setStaticPosition(staticPosition);
20642064
op.getOperation()->setOperands(operands);
2065+
// Return the original result to indicate an in-place folding happened.
20652066
return op.getResult();
20662067
}
20672068
return {};
@@ -2148,8 +2149,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
21482149
// Fold `arith.constant` indices into the `vector.extract` operation. Make
21492150
// sure that patterns requiring constant indices are added after this fold.
21502151
SmallVector<Value> operands = {getVector()};
2151-
if (auto val = extractInsertFoldConstantOp(*this, adaptor, operands))
2152-
return val;
2152+
auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands);
2153+
21532154
if (auto res = foldPoisonIndexInsertExtractOp(
21542155
getContext(), adaptor.getStaticPosition(), kPoisonIndex))
21552156
return res;
@@ -2171,7 +2172,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
21712172
return val;
21722173
if (auto val = foldScalarExtractFromFromElements(*this))
21732174
return val;
2174-
return OpFoldResult();
2175+
2176+
return inplaceFolded;
21752177
}
21762178

21772179
namespace {
@@ -3150,8 +3152,8 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
31503152
// Fold `arith.constant` indices into the `vector.insert` operation. Make
31513153
// sure that patterns requiring constant indices are added after this fold.
31523154
SmallVector<Value> operands = {getValueToStore(), getDest()};
3153-
if (auto val = extractInsertFoldConstantOp(*this, adaptor, operands))
3154-
return val;
3155+
auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands);
3156+
31553157
if (auto res = foldPoisonIndexInsertExtractOp(
31563158
getContext(), adaptor.getStaticPosition(), kPoisonIndex))
31573159
return res;
@@ -3161,7 +3163,7 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
31613163
return res;
31623164
}
31633165

3164-
return {};
3166+
return inplaceFolded;
31653167
}
31663168

31673169
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)