@@ -1020,39 +1020,33 @@ struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {
10201020 SmallVector<Location> locs (convertedTypes.size (), newFuncOp.getLoc ());
10211021 entryBlock.addArguments (convertedTypes, locs);
10221022
1023- // Replace the placeholder values with the new arguments.
1024- size_t unrolledInputIdx = 0 ;
1025- newFuncOp.walk ([&](Operation *op) {
1026- // We first look for operands that are placeholders for initially legal
1027- // arguments.
1028- for (auto [operandIdx, operandVal] : llvm::enumerate (op->getOperands ())) {
1029- Operation *operandOp = operandVal.getDefiningOp ();
1030- if (auto it = tmpOps.find (operandOp); it != tmpOps.end ()) {
1031- size_t idx = operandIdx;
1032- rewriter.modifyOpInPlace (op, [&] {
1033- op->setOperand (idx, newFuncOp.getArgument (it->second ));
1034- });
1035- }
1036- }
1037-
1038- // Only consider `vector.insert_strided_slice` ops that were newly created
1039- // at the beginning of the entry block. Once we encounter operations
1040- // outside the entry block or past the `newOpCount`-th operation in the
1041- // entry block, we skip and leave exisintg `vector.insert_strided_slice`
1042- // ops as is.
1043- if (op->getBlock () != &entryBlock ||
1044- static_cast <size_t >(std::distance (entryBlock.begin (),
1045- op->getIterator ())) >= newOpCount)
1046- return ;
1023+ // Replace all uses of placeholders for initially legal arguments with their
1024+ // original function arguments (that were added to `newFuncOp`).
1025+ for (auto &[placeholderOp, argIdx] : tmpOps) {
1026+ if (!placeholderOp)
1027+ continue ;
1028+ Value replacement = newFuncOp.getArgument (argIdx);
1029+ rewriter.replaceAllUsesWith (placeholderOp->getResult (0 ), replacement);
1030+ }
10471031
1032+ // Replace dummy operands of new `vector.insert_strided_slice` ops with
1033+ // their corresponding new function arguments.
1034+ size_t unrolledInputIdx = 0 ;
1035+ for (auto [count, op] : enumerate(entryBlock.getOperations ())) {
1036+ Operation &curOp = op;
1037+ // Since all newly created operations are in the beginning, reaching the
1038+ // end of them means that any later `vector.insert_strided_slice` should
1039+ // not be touched.
1040+ if (count >= newOpCount)
1041+ continue ;
10481042 if (auto vecOp = dyn_cast<vector::InsertStridedSliceOp>(op)) {
10491043 size_t unrolledInputNo = unrolledInputNums[unrolledInputIdx];
1050- rewriter.modifyOpInPlace (op , [&] {
1051- op-> setOperand (0 , newFuncOp.getArgument (unrolledInputNo));
1044+ rewriter.modifyOpInPlace (&curOp , [&] {
1045+ curOp. setOperand (0 , newFuncOp.getArgument (unrolledInputNo));
10521046 });
10531047 ++unrolledInputIdx;
10541048 }
1055- });
1049+ }
10561050
10571051 // Erase the original funcOp. The `tmpOps` do not need to be erased since
10581052 // they have no uses and will be handled by dead-code elimination.
0 commit comments