diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index bbb366b01fa6e..cf2df1f24f91f 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -2143,11 +2143,16 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) { // mismatch). if (getNumIndices() == 0 && getVector().getType() == getResult().getType()) return getVector(); + if (auto res = foldPoisonSrcExtractOp(adaptor.getVector())) + return res; + // Fold `arith.constant` indices into the `vector.extract` operation. Make + // sure that patterns requiring constant indices are added after this fold. + SmallVector operands = {getVector()}; + if (auto val = extractInsertFoldConstantOp(*this, adaptor, operands)) + return val; if (auto res = foldPoisonIndexInsertExtractOp( getContext(), adaptor.getStaticPosition(), kPoisonIndex)) return res; - if (auto res = foldPoisonSrcExtractOp(adaptor.getVector())) - return res; if (auto res = foldDenseElementsAttrSrcExtractOp(*this, adaptor.getVector())) return res; if (succeeded(foldExtractOpFromExtractChain(*this))) @@ -2166,9 +2171,6 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) { return val; if (auto val = foldScalarExtractFromFromElements(*this)) return val; - SmallVector operands = {getVector()}; - if (auto val = extractInsertFoldConstantOp(*this, adaptor, operands)) - return val; return OpFoldResult(); } @@ -3145,6 +3147,8 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) { // (type mismatch). if (getNumIndices() == 0 && getValueToStoreType() == getType()) return getValueToStore(); + // Fold `arith.constant` indices into the `vector.insert` operation. Make + // sure that patterns requiring constant indices are added after this fold. SmallVector operands = {getValueToStore(), getDest()}; if (auto val = extractInsertFoldConstantOp(*this, adaptor, operands)) return val;