Skip to content

Commit 0a71fd1

Browse files
authored
[MLIR][Vector] Improve warp distribution robustness (#161647)
1 parent 0e6557d commit 0a71fd1

File tree

2 files changed

+41
-40
lines changed

2 files changed

+41
-40
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp

Lines changed: 22 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -928,17 +928,20 @@ struct WarpOpDeadResult : public WarpDistributionPattern {
928928
// Some values may be yielded multiple times and correspond to multiple
929929
// results. Deduplicating occurs by taking each result with its matching
930930
// yielded value, and:
931-
// 1. recording the unique first position at which the value is yielded.
931+
// 1. recording the unique first position at which the value with uses is
932+
// yielded.
932933
// 2. recording for the result, the first position at which the dedup'ed
933934
// value is yielded.
934935
// 3. skipping from the new result types / new yielded values any result
935936
// that has no use or whose yielded value has already been seen.
936937
for (OpResult result : warpOp.getResults()) {
938+
if (result.use_empty())
939+
continue;
937940
Value yieldOperand = yield.getOperand(result.getResultNumber());
938941
auto it = dedupYieldOperandPositionMap.insert(
939942
std::make_pair(yieldOperand, newResultTypes.size()));
940943
dedupResultPositionMap.insert(std::make_pair(result, it.first->second));
941-
if (result.use_empty() || !it.second)
944+
if (!it.second)
942945
continue;
943946
newResultTypes.push_back(result.getType());
944947
newYieldValues.push_back(yieldOperand);
@@ -1843,16 +1846,16 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
18431846
newWarpOpDistTypes.append(escapingValueDistTypesElse.begin(),
18441847
escapingValueDistTypesElse.end());
18451848

1846-
llvm::SmallDenseMap<unsigned, unsigned> origToNewYieldIdx;
18471849
for (auto [idx, val] :
18481850
llvm::zip_equal(nonIfYieldIndices, nonIfYieldValues)) {
1849-
origToNewYieldIdx[idx] = newWarpOpYieldValues.size();
18501851
newWarpOpYieldValues.push_back(val);
18511852
newWarpOpDistTypes.push_back(warpOp.getResult(idx).getType());
18521853
}
1853-
// Create the new `WarpOp` with the updated yield values and types.
1854-
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
1855-
rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
1854+
// Replace the old `WarpOp` with the new one that has additional yield
1855+
// values and types.
1856+
SmallVector<size_t> newIndices;
1857+
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1858+
rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes, newIndices);
18561859
// `ifOp` returns the result of the inner warp op.
18571860
SmallVector<Type> newIfOpDistResTypes;
18581861
for (auto [i, res] : llvm::enumerate(ifOp.getResults())) {
@@ -1870,8 +1873,8 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
18701873
OpBuilder::InsertionGuard g(rewriter);
18711874
rewriter.setInsertionPointAfter(newWarpOp);
18721875
auto newIfOp = scf::IfOp::create(
1873-
rewriter, ifOp.getLoc(), newIfOpDistResTypes, newWarpOp.getResult(0),
1874-
static_cast<bool>(ifOp.thenBlock()),
1876+
rewriter, ifOp.getLoc(), newIfOpDistResTypes,
1877+
newWarpOp.getResult(newIndices[0]), static_cast<bool>(ifOp.thenBlock()),
18751878
static_cast<bool>(ifOp.elseBlock()));
18761879
auto encloseRegionInWarpOp =
18771880
[&](Block *oldIfBranch, Block *newIfBranch,
@@ -1888,7 +1891,7 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
18881891
for (size_t i = 0; i < escapingValues.size();
18891892
++i, ++warpResRangeStart) {
18901893
innerWarpInputVals.push_back(
1891-
newWarpOp.getResult(warpResRangeStart));
1894+
newWarpOp.getResult(newIndices[warpResRangeStart]));
18921895
escapeValToBlockArgIndex[escapingValues[i]] =
18931896
innerWarpInputTypes.size();
18941897
innerWarpInputTypes.push_back(escapingValueInputTypes[i]);
@@ -1936,17 +1939,8 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
19361939
// Update the users of `<- WarpOp.yield <- IfOp.yield` to use the new `IfOp`
19371940
// result.
19381941
for (auto [origIdx, newIdx] : ifResultMapping)
1939-
rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx),
1942+
rewriter.replaceAllUsesExcept(newWarpOp.getResult(origIdx),
19401943
newIfOp.getResult(newIdx), newIfOp);
1941-
// Similarly, update any users of the `WarpOp` results that were not
1942-
// results of the `IfOp`.
1943-
for (auto [origIdx, newIdx] : origToNewYieldIdx)
1944-
rewriter.replaceAllUsesWith(warpOp.getResult(origIdx),
1945-
newWarpOp.getResult(newIdx));
1946-
// Remove the original `WarpOp` and `IfOp`, they should not have any uses
1947-
// at this point.
1948-
rewriter.eraseOp(ifOp);
1949-
rewriter.eraseOp(warpOp);
19501944
return success();
19511945
}
19521946

@@ -2065,19 +2059,16 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
20652059
escapingValueDistTypes.begin(),
20662060
escapingValueDistTypes.end());
20672061
// Next, we insert all non-`ForOp` yielded values and their distributed
2068-
// types. We also create a mapping between the non-`ForOp` yielded value
2069-
// index and the corresponding new `WarpOp` yield value index (needed to
2070-
// update users later).
2071-
llvm::SmallDenseMap<unsigned, unsigned> nonForResultMapping;
2062+
// types.
20722063
for (auto [i, v] :
20732064
llvm::zip_equal(nonForResultIndices, nonForYieldedValues)) {
2074-
nonForResultMapping[i] = newWarpOpYieldValues.size();
20752065
newWarpOpYieldValues.push_back(v);
20762066
newWarpOpDistTypes.push_back(warpOp.getResult(i).getType());
20772067
}
20782068
// Create the new `WarpOp` with the updated yield values and types.
2079-
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
2080-
rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
2069+
SmallVector<size_t> newIndices;
2070+
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
2071+
rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes, newIndices);
20812072

20822073
// Next, we create a new `ForOp` with the init args yielded by the new
20832074
// `WarpOp`.
@@ -2086,7 +2077,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
20862077
// escaping values in the new `WarpOp`.
20872078
SmallVector<Value> newForOpOperands;
20882079
for (size_t i = 0; i < escapingValuesStartIdx; ++i)
2089-
newForOpOperands.push_back(newWarpOp.getResult(i));
2080+
newForOpOperands.push_back(newWarpOp.getResult(newIndices[i]));
20902081

20912082
// Create a new `ForOp` outside the new `WarpOp` region.
20922083
OpBuilder::InsertionGuard g(rewriter);
@@ -2110,7 +2101,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
21102101
llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
21112102
for (size_t i = escapingValuesStartIdx;
21122103
i < escapingValuesStartIdx + escapingValues.size(); ++i) {
2113-
innerWarpInput.push_back(newWarpOp.getResult(i));
2104+
innerWarpInput.push_back(newWarpOp.getResult(newIndices[i]));
21142105
argIndexMapping[escapingValues[i - escapingValuesStartIdx]] =
21152106
innerWarpInputType.size();
21162107
innerWarpInputType.push_back(
@@ -2146,20 +2137,11 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
21462137
if (!innerWarp.getResults().empty())
21472138
scf::YieldOp::create(rewriter, forOp.getLoc(), innerWarp.getResults());
21482139

2149-
// Update the users of original `WarpOp` results that were coming from the
2140+
// Update the users of the new `WarpOp` results that were coming from the
21502141
// original `ForOp` to the corresponding new `ForOp` result.
21512142
for (auto [origIdx, newIdx] : forResultMapping)
2152-
rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx),
2143+
rewriter.replaceAllUsesExcept(newWarpOp.getResult(origIdx),
21532144
newForOp.getResult(newIdx), newForOp);
2154-
// Similarly, update any users of the `WarpOp` results that were not
2155-
// results of the `ForOp`.
2156-
for (auto [origIdx, newIdx] : nonForResultMapping)
2157-
rewriter.replaceAllUsesWith(warpOp.getResult(origIdx),
2158-
newWarpOp.getResult(newIdx));
2159-
// Remove the original `WarpOp` and `ForOp`, they should not have any uses
2160-
// at this point.
2161-
rewriter.eraseOp(forOp);
2162-
rewriter.eraseOp(warpOp);
21632145
// Update any users of escaping values that were forwarded to the
21642146
// inner `WarpOp`. These values are now arguments of the inner `WarpOp`.
21652147
newForOp.walk([&](Operation *op) {

mlir/test/Dialect/Vector/vector-warp-distribute.mlir

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1925,3 +1925,22 @@ func.func @warp_scf_if_distribute(%pred : i1) {
19251925
// CHECK-PROP: "some_use"(%[[IF_YIELD_DIST]]) : (vector<1xf32>) -> ()
19261926
// CHECK-PROP: return
19271927
// CHECK-PROP: }
1928+
1929+
// -----
1930+
func.func @dedup_unused_result(%laneid : index) -> (vector<1xf32>) {
1931+
%r:3 = gpu.warp_execute_on_lane_0(%laneid)[32] ->
1932+
(vector<1xf32>, vector<2xf32>, vector<1xf32>) {
1933+
%2 = "some_def"() : () -> (vector<32xf32>)
1934+
%3 = "some_def"() : () -> (vector<64xf32>)
1935+
gpu.yield %2, %3, %2 : vector<32xf32>, vector<64xf32>, vector<32xf32>
1936+
}
1937+
%r0 = "some_use"(%r#2, %r#2) : (vector<1xf32>, vector<1xf32>) -> (vector<1xf32>)
1938+
return %r0 : vector<1xf32>
1939+
}
1940+
1941+
// CHECK-PROP: func @dedup_unused_result
1942+
// CHECK-PROP: %[[R:.*]] = gpu.warp_execute_on_lane_0(%arg0)[32] -> (vector<1xf32>)
1943+
// CHECK-PROP: %[[Y0:.*]] = "some_def"() : () -> vector<32xf32>
1944+
// CHECK-PROP: %[[Y1:.*]] = "some_def"() : () -> vector<64xf32>
1945+
// CHECK-PROP: gpu.yield %[[Y0]] : vector<32xf32>
1946+
// CHECK-PROP: "some_use"(%[[R]], %[[R]]) : (vector<1xf32>, vector<1xf32>) -> vector<1xf32>

0 commit comments

Comments
 (0)