Skip to content

Commit e496a3e

Browse files
committed
Improve implementation to note use walk
1 parent 190b9aa commit e496a3e

File tree

1 file changed

+21
-27
lines changed

1 file changed

+21
-27
lines changed

mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp

Lines changed: 21 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1020,39 +1020,33 @@ struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {
10201020
SmallVector<Location> locs(convertedTypes.size(), newFuncOp.getLoc());
10211021
entryBlock.addArguments(convertedTypes, locs);
10221022

1023-
// Replace the placeholder values with the new arguments.
1024-
size_t unrolledInputIdx = 0;
1025-
newFuncOp.walk([&](Operation *op) {
1026-
// We first look for operands that are placeholders for initially legal
1027-
// arguments.
1028-
for (auto [operandIdx, operandVal] : llvm::enumerate(op->getOperands())) {
1029-
Operation *operandOp = operandVal.getDefiningOp();
1030-
if (auto it = tmpOps.find(operandOp); it != tmpOps.end()) {
1031-
size_t idx = operandIdx;
1032-
rewriter.modifyOpInPlace(op, [&] {
1033-
op->setOperand(idx, newFuncOp.getArgument(it->second));
1034-
});
1035-
}
1036-
}
1037-
1038-
// Only consider `vector.insert_strided_slice` ops that were newly created
1039-
// at the beginning of the entry block. Once we encounter operations
1040-
// outside the entry block or past the `newOpCount`-th operation in the
1041-
// entry block, we skip and leave exisintg `vector.insert_strided_slice`
1042-
// ops as is.
1043-
if (op->getBlock() != &entryBlock ||
1044-
static_cast<size_t>(std::distance(entryBlock.begin(),
1045-
op->getIterator())) >= newOpCount)
1046-
return;
1023+
// Replace all uses of placeholders for initially legal arguments with their
1024+
// original function arguments (that were added to `newFuncOp`).
1025+
for (auto &[placeholderOp, argIdx] : tmpOps) {
1026+
if (!placeholderOp)
1027+
continue;
1028+
Value replacement = newFuncOp.getArgument(argIdx);
1029+
rewriter.replaceAllUsesWith(placeholderOp->getResult(0), replacement);
1030+
}
10471031

1032+
// Replace dummy operands of new `vector.insert_strided_slice` ops with
1033+
// their corresponding new function arguments.
1034+
size_t unrolledInputIdx = 0;
1035+
for (auto [count, op] : enumerate(entryBlock.getOperations())) {
1036+
Operation &curOp = op;
1037+
// Since all newly created operations are in the beginning, reaching the
1038+
// end of them means that any later `vector.insert_strided_slice` should
1039+
// not be touched.
1040+
if (count >= newOpCount)
1041+
continue;
10481042
if (auto vecOp = dyn_cast<vector::InsertStridedSliceOp>(op)) {
10491043
size_t unrolledInputNo = unrolledInputNums[unrolledInputIdx];
1050-
rewriter.modifyOpInPlace(op, [&] {
1051-
op->setOperand(0, newFuncOp.getArgument(unrolledInputNo));
1044+
rewriter.modifyOpInPlace(&curOp, [&] {
1045+
curOp.setOperand(0, newFuncOp.getArgument(unrolledInputNo));
10521046
});
10531047
++unrolledInputIdx;
10541048
}
1055-
});
1049+
}
10561050

10571051
// Erase the original funcOp. The `tmpOps` do not need to be erased since
10581052
// they have no uses and will be handled by dead-code elimination.

0 commit comments

Comments
 (0)