@@ -928,17 +928,20 @@ struct WarpOpDeadResult : public WarpDistributionPattern {
928928 // Some values may be yielded multiple times and correspond to multiple
929929 // results. Deduplicating occurs by taking each result with its matching
930930 // yielded value, and:
931- // 1. recording the unique first position at which the value is yielded.
931+ // 1. recording the unique first position at which the value with uses is
932+ // yielded.
932933 // 2. recording for the result, the first position at which the dedup'ed
933934 // value is yielded.
934935 // 3. skipping from the new result types / new yielded values any result
935936 // that has no use or whose yielded value has already been seen.
936937 for (OpResult result : warpOp.getResults ()) {
938+ if (result.use_empty ())
939+ continue ;
937940 Value yieldOperand = yield.getOperand (result.getResultNumber ());
938941 auto it = dedupYieldOperandPositionMap.insert (
939942 std::make_pair (yieldOperand, newResultTypes.size ()));
940943 dedupResultPositionMap.insert (std::make_pair (result, it.first ->second ));
941- if (result. use_empty () || !it.second )
944+ if (!it.second )
942945 continue ;
943946 newResultTypes.push_back (result.getType ());
944947 newYieldValues.push_back (yieldOperand);
@@ -1843,16 +1846,16 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
18431846 newWarpOpDistTypes.append (escapingValueDistTypesElse.begin (),
18441847 escapingValueDistTypesElse.end ());
18451848
1846- llvm::SmallDenseMap<unsigned , unsigned > origToNewYieldIdx;
18471849 for (auto [idx, val] :
18481850 llvm::zip_equal (nonIfYieldIndices, nonIfYieldValues)) {
1849- origToNewYieldIdx[idx] = newWarpOpYieldValues.size ();
18501851 newWarpOpYieldValues.push_back (val);
18511852 newWarpOpDistTypes.push_back (warpOp.getResult (idx).getType ());
18521853 }
1853- // Create the new `WarpOp` with the updated yield values and types.
1854- WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns (
1855- rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
1854+ // Replace the old `WarpOp` with the new one that has additional yield
1855+ // values and types.
1856+ SmallVector<size_t > newIndices;
1857+ WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns (
1858+ rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes, newIndices);
18561859 // `ifOp` returns the result of the inner warp op.
18571860 SmallVector<Type> newIfOpDistResTypes;
18581861 for (auto [i, res] : llvm::enumerate (ifOp.getResults ())) {
@@ -1870,8 +1873,8 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
18701873 OpBuilder::InsertionGuard g (rewriter);
18711874 rewriter.setInsertionPointAfter (newWarpOp);
18721875 auto newIfOp = scf::IfOp::create (
1873- rewriter, ifOp.getLoc (), newIfOpDistResTypes, newWarpOp. getResult ( 0 ),
1874- static_cast <bool >(ifOp.thenBlock ()),
1876+ rewriter, ifOp.getLoc (), newIfOpDistResTypes,
1877+ newWarpOp. getResult (newIndices[ 0 ]), static_cast <bool >(ifOp.thenBlock ()),
18751878 static_cast <bool >(ifOp.elseBlock ()));
18761879 auto encloseRegionInWarpOp =
18771880 [&](Block *oldIfBranch, Block *newIfBranch,
@@ -1888,7 +1891,7 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
18881891 for (size_t i = 0 ; i < escapingValues.size ();
18891892 ++i, ++warpResRangeStart) {
18901893 innerWarpInputVals.push_back (
1891- newWarpOp.getResult (warpResRangeStart));
1894+ newWarpOp.getResult (newIndices[ warpResRangeStart] ));
18921895 escapeValToBlockArgIndex[escapingValues[i]] =
18931896 innerWarpInputTypes.size ();
18941897 innerWarpInputTypes.push_back (escapingValueInputTypes[i]);
@@ -1936,17 +1939,8 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
19361939 // Update the users of `<- WarpOp.yield <- IfOp.yield` to use the new `IfOp`
19371940 // result.
19381941 for (auto [origIdx, newIdx] : ifResultMapping)
1939- rewriter.replaceAllUsesExcept (warpOp .getResult (origIdx),
1942+ rewriter.replaceAllUsesExcept (newWarpOp .getResult (origIdx),
19401943 newIfOp.getResult (newIdx), newIfOp);
1941- // Similarly, update any users of the `WarpOp` results that were not
1942- // results of the `IfOp`.
1943- for (auto [origIdx, newIdx] : origToNewYieldIdx)
1944- rewriter.replaceAllUsesWith (warpOp.getResult (origIdx),
1945- newWarpOp.getResult (newIdx));
1946- // Remove the original `WarpOp` and `IfOp`, they should not have any uses
1947- // at this point.
1948- rewriter.eraseOp (ifOp);
1949- rewriter.eraseOp (warpOp);
19501944 return success ();
19511945 }
19521946
@@ -2065,19 +2059,16 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
20652059 escapingValueDistTypes.begin (),
20662060 escapingValueDistTypes.end ());
20672061 // Next, we insert all non-`ForOp` yielded values and their distributed
2068- // types. We also create a mapping between the non-`ForOp` yielded value
2069- // index and the corresponding new `WarpOp` yield value index (needed to
2070- // update users later).
2071- llvm::SmallDenseMap<unsigned , unsigned > nonForResultMapping;
2062+ // types.
20722063 for (auto [i, v] :
20732064 llvm::zip_equal (nonForResultIndices, nonForYieldedValues)) {
2074- nonForResultMapping[i] = newWarpOpYieldValues.size ();
20752065 newWarpOpYieldValues.push_back (v);
20762066 newWarpOpDistTypes.push_back (warpOp.getResult (i).getType ());
20772067 }
20782068 // Create the new `WarpOp` with the updated yield values and types.
2079- WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns (
2080- rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
2069+ SmallVector<size_t > newIndices;
2070+ WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns (
2071+ rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes, newIndices);
20812072
20822073 // Next, we create a new `ForOp` with the init args yielded by the new
20832074 // `WarpOp`.
@@ -2086,7 +2077,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
20862077 // escaping values in the new `WarpOp`.
20872078 SmallVector<Value> newForOpOperands;
20882079 for (size_t i = 0 ; i < escapingValuesStartIdx; ++i)
2089- newForOpOperands.push_back (newWarpOp.getResult (i ));
2080+ newForOpOperands.push_back (newWarpOp.getResult (newIndices[i] ));
20902081
20912082 // Create a new `ForOp` outside the new `WarpOp` region.
20922083 OpBuilder::InsertionGuard g (rewriter);
@@ -2110,7 +2101,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
21102101 llvm::SmallDenseMap<Value, int64_t > argIndexMapping;
21112102 for (size_t i = escapingValuesStartIdx;
21122103 i < escapingValuesStartIdx + escapingValues.size (); ++i) {
2113- innerWarpInput.push_back (newWarpOp.getResult (i ));
2104+ innerWarpInput.push_back (newWarpOp.getResult (newIndices[i] ));
21142105 argIndexMapping[escapingValues[i - escapingValuesStartIdx]] =
21152106 innerWarpInputType.size ();
21162107 innerWarpInputType.push_back (
@@ -2146,20 +2137,11 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
21462137 if (!innerWarp.getResults ().empty ())
21472138 scf::YieldOp::create (rewriter, forOp.getLoc (), innerWarp.getResults ());
21482139
2149- // Update the users of original `WarpOp` results that were coming from the
2140+ // Update the users of the new `WarpOp` results that were coming from the
21502141 // original `ForOp` to the corresponding new `ForOp` result.
21512142 for (auto [origIdx, newIdx] : forResultMapping)
2152- rewriter.replaceAllUsesExcept (warpOp .getResult (origIdx),
2143+ rewriter.replaceAllUsesExcept (newWarpOp .getResult (origIdx),
21532144 newForOp.getResult (newIdx), newForOp);
2154- // Similarly, update any users of the `WarpOp` results that were not
2155- // results of the `ForOp`.
2156- for (auto [origIdx, newIdx] : nonForResultMapping)
2157- rewriter.replaceAllUsesWith (warpOp.getResult (origIdx),
2158- newWarpOp.getResult (newIdx));
2159- // Remove the original `WarpOp` and `ForOp`, they should not have any uses
2160- // at this point.
2161- rewriter.eraseOp (forOp);
2162- rewriter.eraseOp (warpOp);
21632145 // Update any users of escaping values that were forwarded to the
21642146 // inner `WarpOp`. These values are now arguments of the inner `WarpOp`.
21652147 newForOp.walk ([&](Operation *op) {
0 commit comments