Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 22 additions & 40 deletions mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -928,17 +928,20 @@ struct WarpOpDeadResult : public WarpDistributionPattern {
// Some values may be yielded multiple times and correspond to multiple
// results. Deduplicating occurs by taking each result with its matching
// yielded value, and:
// 1. recording the unique first position at which the value is yielded.
// 1. recording the unique first position at which the value with uses is
// yielded.
// 2. recording for the result, the first position at which the dedup'ed
// value is yielded.
// 3. skipping from the new result types / new yielded values any result
// that has no use or whose yielded value has already been seen.
for (OpResult result : warpOp.getResults()) {
if (result.use_empty())
continue;
Value yieldOperand = yield.getOperand(result.getResultNumber());
auto it = dedupYieldOperandPositionMap.insert(
std::make_pair(yieldOperand, newResultTypes.size()));
dedupResultPositionMap.insert(std::make_pair(result, it.first->second));
if (result.use_empty() || !it.second)
if (!it.second)
continue;
newResultTypes.push_back(result.getType());
newYieldValues.push_back(yieldOperand);
Expand Down Expand Up @@ -1843,16 +1846,16 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
newWarpOpDistTypes.append(escapingValueDistTypesElse.begin(),
escapingValueDistTypesElse.end());

llvm::SmallDenseMap<unsigned, unsigned> origToNewYieldIdx;
for (auto [idx, val] :
llvm::zip_equal(nonIfYieldIndices, nonIfYieldValues)) {
origToNewYieldIdx[idx] = newWarpOpYieldValues.size();
newWarpOpYieldValues.push_back(val);
newWarpOpDistTypes.push_back(warpOp.getResult(idx).getType());
}
// Create the new `WarpOp` with the updated yield values and types.
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
// Replace the old `WarpOp` with the new one that has additional yield
// values and types.
SmallVector<size_t> newIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes, newIndices);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the exisiting code is still correct. all other patterns just append the operands so should be scf.for and scf.if.

I would try to handle this issue at the high priority op that gets duplicated. We can avoid sinking if is duplicated multiple times and wait for deduplication to hit.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why can't the higher rank pattern check if its result is duplicated in the yield list

This would actually make a good rule for all patterns in general. As for this particular PR, there is no reason to knowingly allow duplicated values when they can be avoided using existing utilities.

all other patterns just append the operands so should be scf.for and scf.if

All other distribution patterns use moveRegionToNewWarpOpAndAppendReturns (which does deduplication), not moveRegionToNewWarpOpAndReplaceReturns which simply sets the result types. So in this PR, we actually align with other patterns.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So in this PR, we actually align with other patterns.

Got it. I missed it.

// `ifOp` returns the result of the inner warp op.
SmallVector<Type> newIfOpDistResTypes;
for (auto [i, res] : llvm::enumerate(ifOp.getResults())) {
Expand All @@ -1870,8 +1873,8 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPointAfter(newWarpOp);
auto newIfOp = scf::IfOp::create(
rewriter, ifOp.getLoc(), newIfOpDistResTypes, newWarpOp.getResult(0),
static_cast<bool>(ifOp.thenBlock()),
rewriter, ifOp.getLoc(), newIfOpDistResTypes,
newWarpOp.getResult(newIndices[0]), static_cast<bool>(ifOp.thenBlock()),
static_cast<bool>(ifOp.elseBlock()));
auto encloseRegionInWarpOp =
[&](Block *oldIfBranch, Block *newIfBranch,
Expand All @@ -1888,7 +1891,7 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
for (size_t i = 0; i < escapingValues.size();
++i, ++warpResRangeStart) {
innerWarpInputVals.push_back(
newWarpOp.getResult(warpResRangeStart));
newWarpOp.getResult(newIndices[warpResRangeStart]));
escapeValToBlockArgIndex[escapingValues[i]] =
innerWarpInputTypes.size();
innerWarpInputTypes.push_back(escapingValueInputTypes[i]);
Expand Down Expand Up @@ -1936,17 +1939,8 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
// Update the users of `<- WarpOp.yield <- IfOp.yield` to use the new `IfOp`
// result.
for (auto [origIdx, newIdx] : ifResultMapping)
rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx),
rewriter.replaceAllUsesExcept(newWarpOp.getResult(origIdx),
newIfOp.getResult(newIdx), newIfOp);
// Similarly, update any users of the `WarpOp` results that were not
// results of the `IfOp`.
for (auto [origIdx, newIdx] : origToNewYieldIdx)
rewriter.replaceAllUsesWith(warpOp.getResult(origIdx),
newWarpOp.getResult(newIdx));
// Remove the original `WarpOp` and `IfOp`, they should not have any uses
// at this point.
rewriter.eraseOp(ifOp);
rewriter.eraseOp(warpOp);
return success();
}

Expand Down Expand Up @@ -2065,19 +2059,16 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
escapingValueDistTypes.begin(),
escapingValueDistTypes.end());
// Next, we insert all non-`ForOp` yielded values and their distributed
// types. We also create a mapping between the non-`ForOp` yielded value
// index and the corresponding new `WarpOp` yield value index (needed to
// update users later).
llvm::SmallDenseMap<unsigned, unsigned> nonForResultMapping;
// types.
for (auto [i, v] :
llvm::zip_equal(nonForResultIndices, nonForYieldedValues)) {
nonForResultMapping[i] = newWarpOpYieldValues.size();
newWarpOpYieldValues.push_back(v);
newWarpOpDistTypes.push_back(warpOp.getResult(i).getType());
}
// Create the new `WarpOp` with the updated yield values and types.
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
SmallVector<size_t> newIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes, newIndices);

// Next, we create a new `ForOp` with the init args yielded by the new
// `WarpOp`.
Expand All @@ -2086,7 +2077,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
// escaping values in the new `WarpOp`.
SmallVector<Value> newForOpOperands;
for (size_t i = 0; i < escapingValuesStartIdx; ++i)
newForOpOperands.push_back(newWarpOp.getResult(i));
newForOpOperands.push_back(newWarpOp.getResult(newIndices[i]));

// Create a new `ForOp` outside the new `WarpOp` region.
OpBuilder::InsertionGuard g(rewriter);
Expand All @@ -2110,7 +2101,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
for (size_t i = escapingValuesStartIdx;
i < escapingValuesStartIdx + escapingValues.size(); ++i) {
innerWarpInput.push_back(newWarpOp.getResult(i));
innerWarpInput.push_back(newWarpOp.getResult(newIndices[i]));
argIndexMapping[escapingValues[i - escapingValuesStartIdx]] =
innerWarpInputType.size();
innerWarpInputType.push_back(
Expand Down Expand Up @@ -2146,20 +2137,11 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
if (!innerWarp.getResults().empty())
scf::YieldOp::create(rewriter, forOp.getLoc(), innerWarp.getResults());

// Update the users of original `WarpOp` results that were coming from the
// Update the users of the new `WarpOp` results that were coming from the
// original `ForOp` to the corresponding new `ForOp` result.
for (auto [origIdx, newIdx] : forResultMapping)
rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx),
rewriter.replaceAllUsesExcept(newWarpOp.getResult(origIdx),
newForOp.getResult(newIdx), newForOp);
// Similarly, update any users of the `WarpOp` results that were not
// results of the `ForOp`.
for (auto [origIdx, newIdx] : nonForResultMapping)
rewriter.replaceAllUsesWith(warpOp.getResult(origIdx),
newWarpOp.getResult(newIdx));
// Remove the original `WarpOp` and `ForOp`, they should not have any uses
// at this point.
rewriter.eraseOp(forOp);
rewriter.eraseOp(warpOp);
// Update any users of escaping values that were forwarded to the
// inner `WarpOp`. These values are now arguments of the inner `WarpOp`.
newForOp.walk([&](Operation *op) {
Expand Down
19 changes: 19 additions & 0 deletions mlir/test/Dialect/Vector/vector-warp-distribute.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1925,3 +1925,22 @@ func.func @warp_scf_if_distribute(%pred : i1) {
// CHECK-PROP: "some_use"(%[[IF_YIELD_DIST]]) : (vector<1xf32>) -> ()
// CHECK-PROP: return
// CHECK-PROP: }

// -----
func.func @dedup_unused_result(%laneid : index) -> (vector<1xf32>) {
%r:3 = gpu.warp_execute_on_lane_0(%laneid)[32] ->
(vector<1xf32>, vector<2xf32>, vector<1xf32>) {
%2 = "some_def"() : () -> (vector<32xf32>)
%3 = "some_def"() : () -> (vector<64xf32>)
gpu.yield %2, %3, %2 : vector<32xf32>, vector<64xf32>, vector<32xf32>
}
%r0 = "some_use"(%r#2, %r#2) : (vector<1xf32>, vector<1xf32>) -> (vector<1xf32>)
return %r0 : vector<1xf32>
}

// CHECK-PROP: func @dedup_unused_result
// CHECK-PROP: %[[R:.*]] = gpu.warp_execute_on_lane_0(%arg0)[32] -> (vector<1xf32>)
// CHECK-PROP: %[[Y0:.*]] = "some_def"() : () -> vector<32xf32>
// CHECK-PROP: %[[Y1:.*]] = "some_def"() : () -> vector<64xf32>
// CHECK-PROP: gpu.yield %[[Y0]] : vector<32xf32>
// CHECK-PROP: "some_use"(%[[R]], %[[R]]) : (vector<1xf32>, vector<1xf32>) -> vector<1xf32>