-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[MLIR][Vector] Improve warp distribution robustness #161647
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -928,17 +928,20 @@ struct WarpOpDeadResult : public WarpDistributionPattern { | |
| // Some values may be yielded multiple times and correspond to multiple | ||
| // results. Deduplicating occurs by taking each result with its matching | ||
| // yielded value, and: | ||
| // 1. recording the unique first position at which the value is yielded. | ||
| // 1. recording the unique first position at which the value with uses is | ||
| // yielded. | ||
| // 2. recording for the result, the first position at which the dedup'ed | ||
| // value is yielded. | ||
| // 3. skipping from the new result types / new yielded values any result | ||
| // that has no use or whose yielded value has already been seen. | ||
| for (OpResult result : warpOp.getResults()) { | ||
| if (result.use_empty()) | ||
| continue; | ||
| Value yieldOperand = yield.getOperand(result.getResultNumber()); | ||
| auto it = dedupYieldOperandPositionMap.insert( | ||
| std::make_pair(yieldOperand, newResultTypes.size())); | ||
| dedupResultPositionMap.insert(std::make_pair(result, it.first->second)); | ||
| if (result.use_empty() || !it.second) | ||
| if (!it.second) | ||
| continue; | ||
| newResultTypes.push_back(result.getType()); | ||
| newYieldValues.push_back(yieldOperand); | ||
|
|
@@ -1843,16 +1846,16 @@ struct WarpOpScfIfOp : public WarpDistributionPattern { | |
| newWarpOpDistTypes.append(escapingValueDistTypesElse.begin(), | ||
| escapingValueDistTypesElse.end()); | ||
|
|
||
| llvm::SmallDenseMap<unsigned, unsigned> origToNewYieldIdx; | ||
| for (auto [idx, val] : | ||
| llvm::zip_equal(nonIfYieldIndices, nonIfYieldValues)) { | ||
| origToNewYieldIdx[idx] = newWarpOpYieldValues.size(); | ||
| newWarpOpYieldValues.push_back(val); | ||
| newWarpOpDistTypes.push_back(warpOp.getResult(idx).getType()); | ||
| } | ||
| // Create the new `WarpOp` with the updated yield values and types. | ||
| WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns( | ||
| rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes); | ||
| // Replace the old `WarpOp` with the new one that has additional yield | ||
| // values and types. | ||
| SmallVector<size_t> newIndices; | ||
| WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( | ||
| rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes, newIndices); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think the exisiting code is still correct. all other patterns just append the operands so should be scf.for and scf.if. I would try to handle this issue at the high priority op that gets duplicated. We can avoid sinking if is duplicated multiple times and wait for deduplication to hit.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
This would actually make a good rule for all patterns in general. As for this particular PR, there is no reason to knowingly allow duplicated values when they can be avoided using existing utilities.
All other distribution patterns use
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Got it. I missed it. |
||
| // `ifOp` returns the result of the inner warp op. | ||
| SmallVector<Type> newIfOpDistResTypes; | ||
| for (auto [i, res] : llvm::enumerate(ifOp.getResults())) { | ||
|
|
@@ -1870,8 +1873,8 @@ struct WarpOpScfIfOp : public WarpDistributionPattern { | |
| OpBuilder::InsertionGuard g(rewriter); | ||
| rewriter.setInsertionPointAfter(newWarpOp); | ||
| auto newIfOp = scf::IfOp::create( | ||
| rewriter, ifOp.getLoc(), newIfOpDistResTypes, newWarpOp.getResult(0), | ||
| static_cast<bool>(ifOp.thenBlock()), | ||
| rewriter, ifOp.getLoc(), newIfOpDistResTypes, | ||
| newWarpOp.getResult(newIndices[0]), static_cast<bool>(ifOp.thenBlock()), | ||
| static_cast<bool>(ifOp.elseBlock())); | ||
| auto encloseRegionInWarpOp = | ||
| [&](Block *oldIfBranch, Block *newIfBranch, | ||
|
|
@@ -1888,7 +1891,7 @@ struct WarpOpScfIfOp : public WarpDistributionPattern { | |
| for (size_t i = 0; i < escapingValues.size(); | ||
| ++i, ++warpResRangeStart) { | ||
| innerWarpInputVals.push_back( | ||
| newWarpOp.getResult(warpResRangeStart)); | ||
| newWarpOp.getResult(newIndices[warpResRangeStart])); | ||
| escapeValToBlockArgIndex[escapingValues[i]] = | ||
| innerWarpInputTypes.size(); | ||
| innerWarpInputTypes.push_back(escapingValueInputTypes[i]); | ||
|
|
@@ -1936,17 +1939,8 @@ struct WarpOpScfIfOp : public WarpDistributionPattern { | |
| // Update the users of `<- WarpOp.yield <- IfOp.yield` to use the new `IfOp` | ||
| // result. | ||
| for (auto [origIdx, newIdx] : ifResultMapping) | ||
| rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx), | ||
| rewriter.replaceAllUsesExcept(newWarpOp.getResult(origIdx), | ||
| newIfOp.getResult(newIdx), newIfOp); | ||
| // Similarly, update any users of the `WarpOp` results that were not | ||
| // results of the `IfOp`. | ||
| for (auto [origIdx, newIdx] : origToNewYieldIdx) | ||
| rewriter.replaceAllUsesWith(warpOp.getResult(origIdx), | ||
| newWarpOp.getResult(newIdx)); | ||
| // Remove the original `WarpOp` and `IfOp`, they should not have any uses | ||
| // at this point. | ||
| rewriter.eraseOp(ifOp); | ||
| rewriter.eraseOp(warpOp); | ||
| return success(); | ||
| } | ||
|
|
||
|
|
@@ -2065,19 +2059,16 @@ struct WarpOpScfForOp : public WarpDistributionPattern { | |
| escapingValueDistTypes.begin(), | ||
| escapingValueDistTypes.end()); | ||
| // Next, we insert all non-`ForOp` yielded values and their distributed | ||
| // types. We also create a mapping between the non-`ForOp` yielded value | ||
| // index and the corresponding new `WarpOp` yield value index (needed to | ||
| // update users later). | ||
| llvm::SmallDenseMap<unsigned, unsigned> nonForResultMapping; | ||
| // types. | ||
| for (auto [i, v] : | ||
| llvm::zip_equal(nonForResultIndices, nonForYieldedValues)) { | ||
| nonForResultMapping[i] = newWarpOpYieldValues.size(); | ||
| newWarpOpYieldValues.push_back(v); | ||
| newWarpOpDistTypes.push_back(warpOp.getResult(i).getType()); | ||
| } | ||
| // Create the new `WarpOp` with the updated yield values and types. | ||
| WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns( | ||
| rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes); | ||
| SmallVector<size_t> newIndices; | ||
| WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( | ||
| rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes, newIndices); | ||
|
|
||
| // Next, we create a new `ForOp` with the init args yielded by the new | ||
| // `WarpOp`. | ||
|
|
@@ -2086,7 +2077,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern { | |
| // escaping values in the new `WarpOp`. | ||
| SmallVector<Value> newForOpOperands; | ||
| for (size_t i = 0; i < escapingValuesStartIdx; ++i) | ||
| newForOpOperands.push_back(newWarpOp.getResult(i)); | ||
| newForOpOperands.push_back(newWarpOp.getResult(newIndices[i])); | ||
|
|
||
| // Create a new `ForOp` outside the new `WarpOp` region. | ||
| OpBuilder::InsertionGuard g(rewriter); | ||
|
|
@@ -2110,7 +2101,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern { | |
| llvm::SmallDenseMap<Value, int64_t> argIndexMapping; | ||
| for (size_t i = escapingValuesStartIdx; | ||
| i < escapingValuesStartIdx + escapingValues.size(); ++i) { | ||
| innerWarpInput.push_back(newWarpOp.getResult(i)); | ||
| innerWarpInput.push_back(newWarpOp.getResult(newIndices[i])); | ||
| argIndexMapping[escapingValues[i - escapingValuesStartIdx]] = | ||
| innerWarpInputType.size(); | ||
| innerWarpInputType.push_back( | ||
|
|
@@ -2146,20 +2137,11 @@ struct WarpOpScfForOp : public WarpDistributionPattern { | |
| if (!innerWarp.getResults().empty()) | ||
| scf::YieldOp::create(rewriter, forOp.getLoc(), innerWarp.getResults()); | ||
|
|
||
| // Update the users of original `WarpOp` results that were coming from the | ||
| // Update the users of the new `WarpOp` results that were coming from the | ||
| // original `ForOp` to the corresponding new `ForOp` result. | ||
| for (auto [origIdx, newIdx] : forResultMapping) | ||
| rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx), | ||
| rewriter.replaceAllUsesExcept(newWarpOp.getResult(origIdx), | ||
| newForOp.getResult(newIdx), newForOp); | ||
| // Similarly, update any users of the `WarpOp` results that were not | ||
| // results of the `ForOp`. | ||
| for (auto [origIdx, newIdx] : nonForResultMapping) | ||
| rewriter.replaceAllUsesWith(warpOp.getResult(origIdx), | ||
| newWarpOp.getResult(newIdx)); | ||
| // Remove the original `WarpOp` and `ForOp`, they should not have any uses | ||
| // at this point. | ||
| rewriter.eraseOp(forOp); | ||
| rewriter.eraseOp(warpOp); | ||
| // Update any users of escaping values that were forwarded to the | ||
| // inner `WarpOp`. These values are now arguments of the inner `WarpOp`. | ||
| newForOp.walk([&](Operation *op) { | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.