-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[mlir][spirv] Fix FuncOpVectorUnroll to process placeholder values in all blocks #142339
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
393ef36
c720c6b
190b9aa
e496a3e
4f2616d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1020,35 +1020,39 @@ struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> { | |
| SmallVector<Location> locs(convertedTypes.size(), newFuncOp.getLoc()); | ||
| entryBlock.addArguments(convertedTypes, locs); | ||
|
|
||
| // Replace the placeholder values with the new arguments. We assume there is | ||
| // only one block for now. | ||
| // Replace the placeholder values with the new arguments. | ||
| size_t unrolledInputIdx = 0; | ||
| for (auto [count, op] : enumerate(entryBlock.getOperations())) { | ||
| newFuncOp.walk([&](Operation *op) { | ||
| // We first look for operands that are placeholders for initially legal | ||
| // arguments. | ||
| Operation &curOp = op; | ||
| for (auto [operandIdx, operandVal] : llvm::enumerate(op.getOperands())) { | ||
| for (auto [operandIdx, operandVal] : llvm::enumerate(op->getOperands())) { | ||
| Operation *operandOp = operandVal.getDefiningOp(); | ||
| if (auto it = tmpOps.find(operandOp); it != tmpOps.end()) { | ||
| size_t idx = operandIdx; | ||
| rewriter.modifyOpInPlace(&curOp, [&curOp, &newFuncOp, it, idx] { | ||
| curOp.setOperand(idx, newFuncOp.getArgument(it->second)); | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This code originally only processes one block, but the placeholder values for the function arguments were applied with |
||
| rewriter.modifyOpInPlace(op, [&] { | ||
| op->setOperand(idx, newFuncOp.getArgument(it->second)); | ||
| }); | ||
| } | ||
| } | ||
| // Since all newly created operations are in the beginning, reaching the | ||
| // end of them means that any later `vector.insert_strided_slice` should | ||
| // not be touched. | ||
| if (count >= newOpCount) | ||
| continue; | ||
|
|
||
| // Only consider `vector.insert_strided_slice` ops that were newly created | ||
| // at the beginning of the entry block. Once we encounter operations | ||
| // outside the entry block or past the `newOpCount`-th operation in the | ||
| // entry block, we skip and leave exisintg `vector.insert_strided_slice` | ||
| // ops as is. | ||
| if (op->getBlock() != &entryBlock || | ||
| static_cast<size_t>(std::distance(entryBlock.begin(), | ||
| op->getIterator())) >= newOpCount) | ||
| return; | ||
|
|
||
| if (auto vecOp = dyn_cast<vector::InsertStridedSliceOp>(op)) { | ||
| size_t unrolledInputNo = unrolledInputNums[unrolledInputIdx]; | ||
| rewriter.modifyOpInPlace(&curOp, [&] { | ||
| curOp.setOperand(0, newFuncOp.getArgument(unrolledInputNo)); | ||
| rewriter.modifyOpInPlace(op, [&] { | ||
| op->setOperand(0, newFuncOp.getArgument(unrolledInputNo)); | ||
| }); | ||
| ++unrolledInputIdx; | ||
| } | ||
| } | ||
| }); | ||
|
|
||
| // Erase the original funcOp. The `tmpOps` do not need to be erased since | ||
| // they have no uses and will be handled by dead-code elimination. | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.