-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[mlir][spirv] Fix FuncOpVectorUnroll to process placeholder values in all blocks #142339
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-mlir Author: Darren Wihandi (fairywreath) Changes
The current implementation however only replaces back (the second replacement, i.e. replacing the placeholder values to new/legal arguments) the first block of instructions and not all of the blocks. This may leave some instructions to use these placeholder values (which for already legal arguments are just zeroattr values that will get DCE'd) instead of the arguments, which is incorrect. Closes #132158. TODO: add test Full diff: https://github.com/llvm/llvm-project/pull/142339.diff 1 Files Affected:
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 62a24646d0662..84796fdeda03a 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -1020,35 +1020,37 @@ struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {
SmallVector<Location> locs(convertedTypes.size(), newFuncOp.getLoc());
entryBlock.addArguments(convertedTypes, locs);
- // Replace the placeholder values with the new arguments. We assume there is
- // only one block for now.
+ // Replace the placeholder values with the new arguments.
size_t unrolledInputIdx = 0;
- for (auto [count, op] : enumerate(entryBlock.getOperations())) {
+ newFuncOp.walk([&](Operation *op) {
// We first look for operands that are placeholders for initially legal
// arguments.
- Operation &curOp = op;
- for (auto [operandIdx, operandVal] : llvm::enumerate(op.getOperands())) {
+ for (auto [operandIdx, operandVal] : llvm::enumerate(op->getOperands())) {
Operation *operandOp = operandVal.getDefiningOp();
if (auto it = tmpOps.find(operandOp); it != tmpOps.end()) {
size_t idx = operandIdx;
- rewriter.modifyOpInPlace(&curOp, [&curOp, &newFuncOp, it, idx] {
- curOp.setOperand(idx, newFuncOp.getArgument(it->second));
+ rewriter.modifyOpInPlace(op, [&] {
+ op->setOperand(idx, newFuncOp.getArgument(it->second));
});
}
}
+
// Since all newly created operations are in the beginning, reaching the
// end of them means that any later `vector.insert_strided_slice` should
// not be touched.
- if (count >= newOpCount)
- continue;
+ if (op->getBlock() == &entryBlock &&
+ static_cast<size_t>(std::distance(entryBlock.begin(),
+ op->getIterator())) >= newOpCount)
+ return;
+
if (auto vecOp = dyn_cast<vector::InsertStridedSliceOp>(op)) {
size_t unrolledInputNo = unrolledInputNums[unrolledInputIdx];
- rewriter.modifyOpInPlace(&curOp, [&] {
- curOp.setOperand(0, newFuncOp.getArgument(unrolledInputNo));
+ rewriter.modifyOpInPlace(op, [&] {
+ op->setOperand(0, newFuncOp.getArgument(unrolledInputNo));
});
++unrolledInputIdx;
}
- }
+ });
// Erase the original funcOp. The `tmpOps` do not need to be erased since
// they have no uses and will be handled by dead-code elimination.
|
|
@llvm/pr-subscribers-mlir-spirv Author: Darren Wihandi (fairywreath) Changes
The current implementation however only replaces back (the second replacement, i.e. replacing the placeholder values to new/legal arguments) the first block of instructions and not all of the blocks. This may leave some instructions to use these placeholder values (which for already legal arguments are just zeroattr values that will get DCE'd) instead of the arguments, which is incorrect. Closes #132158. TODO: add test Full diff: https://github.com/llvm/llvm-project/pull/142339.diff 1 Files Affected:
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 62a24646d0662..84796fdeda03a 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -1020,35 +1020,37 @@ struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {
SmallVector<Location> locs(convertedTypes.size(), newFuncOp.getLoc());
entryBlock.addArguments(convertedTypes, locs);
- // Replace the placeholder values with the new arguments. We assume there is
- // only one block for now.
+ // Replace the placeholder values with the new arguments.
size_t unrolledInputIdx = 0;
- for (auto [count, op] : enumerate(entryBlock.getOperations())) {
+ newFuncOp.walk([&](Operation *op) {
// We first look for operands that are placeholders for initially legal
// arguments.
- Operation &curOp = op;
- for (auto [operandIdx, operandVal] : llvm::enumerate(op.getOperands())) {
+ for (auto [operandIdx, operandVal] : llvm::enumerate(op->getOperands())) {
Operation *operandOp = operandVal.getDefiningOp();
if (auto it = tmpOps.find(operandOp); it != tmpOps.end()) {
size_t idx = operandIdx;
- rewriter.modifyOpInPlace(&curOp, [&curOp, &newFuncOp, it, idx] {
- curOp.setOperand(idx, newFuncOp.getArgument(it->second));
+ rewriter.modifyOpInPlace(op, [&] {
+ op->setOperand(idx, newFuncOp.getArgument(it->second));
});
}
}
+
// Since all newly created operations are in the beginning, reaching the
// end of them means that any later `vector.insert_strided_slice` should
// not be touched.
- if (count >= newOpCount)
- continue;
+ if (op->getBlock() == &entryBlock &&
+ static_cast<size_t>(std::distance(entryBlock.begin(),
+ op->getIterator())) >= newOpCount)
+ return;
+
if (auto vecOp = dyn_cast<vector::InsertStridedSliceOp>(op)) {
size_t unrolledInputNo = unrolledInputNums[unrolledInputIdx];
- rewriter.modifyOpInPlace(&curOp, [&] {
- curOp.setOperand(0, newFuncOp.getArgument(unrolledInputNo));
+ rewriter.modifyOpInPlace(op, [&] {
+ op->setOperand(0, newFuncOp.getArgument(unrolledInputNo));
});
++unrolledInputIdx;
}
- }
+ });
// Erase the original funcOp. The `tmpOps` do not need to be erased since
// they have no uses and will be handled by dead-code elimination.
|
d5150fe to
190b9aa
Compare
| size_t unrolledInputIdx = 0; | ||
| for (auto [count, op] : enumerate(entryBlock.getOperations())) { | ||
| newFuncOp.walk([&](Operation *op) { | ||
| // We first look for operands that are placeholders for initially legal | ||
| // arguments. | ||
| Operation &curOp = op; | ||
| for (auto [operandIdx, operandVal] : llvm::enumerate(op.getOperands())) { | ||
| for (auto [operandIdx, operandVal] : llvm::enumerate(op->getOperands())) { | ||
| Operation *operandOp = operandVal.getDefiningOp(); | ||
| if (auto it = tmpOps.find(operandOp); it != tmpOps.end()) { | ||
| size_t idx = operandIdx; | ||
| rewriter.modifyOpInPlace(&curOp, [&curOp, &newFuncOp, it, idx] { | ||
| curOp.setOperand(idx, newFuncOp.getArgument(it->second)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This code originally only processes one block, but the placeholder values for the function arguments were applied with replaceAllUsesWith and may replace the function arguments across multiple blocks. The current implementation leaves the placeholder values in the other blocks (except for the first one) as is instead of the original function arguments, which is incorrect.
kuhar
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice, thanks for fixing this
… all blocks (llvm#142339) `FuncOpVectorUnroll` contains logic that replaces function arguments by placeholders values. These replacements also involve changing all instructions in the function that use the arguments to use these placeholders. These placeholder values will later be changed back to use the function arguments (either new or original if already legal). The current implementation however only replaces back (the second replacement, i.e. replacing the placeholder values to new/legal arguments) the first block of instructions and not all of the blocks. This may leave some instructions to use these placeholder values (which for already legal arguments are just zeroattr values that will get DCE'd) instead of the arguments, which is incorrect. Closes llvm#132158.
… all blocks (llvm#142339) `FuncOpVectorUnroll` contains logic that replaces function arguments by placeholders values. These replacements also involve changing all instructions in the function that use the arguments to use these placeholders. These placeholder values will later be changed back to use the function arguments (either new or original if already legal). The current implementation however only replaces back (the second replacement, i.e. replacing the placeholder values to new/legal arguments) the first block of instructions and not all of the blocks. This may leave some instructions to use these placeholder values (which for already legal arguments are just zeroattr values that will get DCE'd) instead of the arguments, which is incorrect. Closes llvm#132158.
FuncOpVectorUnrollcontains logic that replaces function arguments by placeholders values. These replacements also involve changing all instructions in the function that use the arguments to use these placeholders. These placeholder values will later be changed back to use the function arguments (either new or original if already legal).The current implementation however only replaces back (the second replacement, i.e. replacing the placeholder values to new/legal arguments) the first block of instructions and not all of the blocks. This may leave some instructions to use these placeholder values (which for already legal arguments are just zeroattr values that will get DCE'd) instead of the arguments, which is incorrect.
Closes #132158.