Skip to content

Commit 2d77859

Browse files
add logic result and matchPattern on dynamic position.
1 parent b6b4362 commit 2d77859

File tree

1 file changed

+11
-7
lines changed

1 file changed

+11
-7
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1980,13 +1980,13 @@ static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
19801980
// If the dynamic operands of `extractOp` or `insertOp` is result of
19811981
// `constantOp`, then fold it.
19821982
template <typename T>
1983-
static void foldConstantOp(T op, SmallVectorImpl<Value> &operands) {
1983+
static LogicalResult foldConstantOp(T op, SmallVectorImpl<Value> &operands) {
19841984
auto staticPosition = op.getStaticPosition().vec();
19851985
OperandRange dynamicPosition = op.getDynamicPosition();
19861986

19871987
// If the dynamic operands is empty, it is returned directly.
19881988
if (!dynamicPosition.size())
1989-
return;
1989+
return failure();
19901990
unsigned index = 0;
19911991

19921992
// `opChange` is a flog. If it is true, it means to update `op` in place.
@@ -2002,10 +2002,10 @@ static void foldConstantOp(T op, SmallVectorImpl<Value> &operands) {
20022002
continue;
20032003
}
20042004

2005-
if (auto constantOp =
2006-
mlir::dyn_cast<arith::ConstantIndexOp>(position.getDefiningOp())) {
2005+
APInt pos;
2006+
if (matchPattern(position, m_ConstantInt(&pos))) {
20072007
opChange = true;
2008-
staticPosition[i] = constantOp.value();
2008+
staticPosition[i] = pos.getSExtValue();
20092009
continue;
20102010
}
20112011
operands.push_back(position);
@@ -2014,7 +2014,9 @@ static void foldConstantOp(T op, SmallVectorImpl<Value> &operands) {
20142014
if (opChange) {
20152015
op.setStaticPosition(staticPosition);
20162016
op.getOperation()->setOperands(operands);
2017+
return success();
20172018
}
2019+
return failure();
20182020
}
20192021

20202022
OpFoldResult ExtractOp::fold(FoldAdaptor) {
@@ -2040,7 +2042,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor) {
20402042
if (auto val = foldScalarExtractFromFromElements(*this))
20412043
return val;
20422044
SmallVector<Value> operands = {getVector()};
2043-
foldConstantOp(*this, operands);
2045+
if (succeeded(foldConstantOp(*this, operands)))
2046+
return getResult();
20442047
return OpFoldResult();
20452048
}
20462049

@@ -3071,7 +3074,8 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
30713074
if (getNumIndices() == 0 && getSourceType() == getType())
30723075
return getSource();
30733076
SmallVector<Value> operands = {getSource(), getDest()};
3074-
foldConstantOp(*this, operands);
3077+
if (succeeded(foldConstantOp(*this, operands)))
3078+
return getResult();
30753079
return {};
30763080
}
30773081

0 commit comments

Comments
 (0)