@@ -934,11 +934,13 @@ struct WarpOpDeadResult : public WarpDistributionPattern {
934934 // 3. skipping from the new result types / new yielded values any result
935935 // that has no use or whose yielded value has already been seen.
936936 for (OpResult result : warpOp.getResults ()) {
937+ if (result.use_empty ())
938+ continue ;
937939 Value yieldOperand = yield.getOperand (result.getResultNumber ());
938940 auto it = dedupYieldOperandPositionMap.insert (
939941 std::make_pair (yieldOperand, newResultTypes.size ()));
940942 dedupResultPositionMap.insert (std::make_pair (result, it.first ->second ));
941- if (result. use_empty () || !it.second )
943+ if (!it.second )
942944 continue ;
943945 newResultTypes.push_back (result.getType ());
944946 newYieldValues.push_back (yieldOperand);
@@ -1843,16 +1845,16 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
18431845 newWarpOpDistTypes.append (escapingValueDistTypesElse.begin (),
18441846 escapingValueDistTypesElse.end ());
18451847
1846- llvm::SmallDenseMap<unsigned , unsigned > origToNewYieldIdx;
18471848 for (auto [idx, val] :
18481849 llvm::zip_equal (nonIfYieldIndices, nonIfYieldValues)) {
1849- origToNewYieldIdx[idx] = newWarpOpYieldValues.size ();
18501850 newWarpOpYieldValues.push_back (val);
18511851 newWarpOpDistTypes.push_back (warpOp.getResult (idx).getType ());
18521852 }
1853- // Create the new `WarpOp` with the updated yield values and types.
1854- WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns (
1855- rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
1853+ // Replace the old `WarpOp` with the new one that has additional yield
1854+ // values and types.
1855+ SmallVector<size_t > newIndices;
1856+ WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns (
1857+ rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes, newIndices);
18561858 // `ifOp` returns the result of the inner warp op.
18571859 SmallVector<Type> newIfOpDistResTypes;
18581860 for (auto [i, res] : llvm::enumerate (ifOp.getResults ())) {
@@ -1870,8 +1872,8 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
18701872 OpBuilder::InsertionGuard g (rewriter);
18711873 rewriter.setInsertionPointAfter (newWarpOp);
18721874 auto newIfOp = scf::IfOp::create (
1873- rewriter, ifOp.getLoc (), newIfOpDistResTypes, newWarpOp. getResult ( 0 ),
1874- static_cast <bool >(ifOp.thenBlock ()),
1875+ rewriter, ifOp.getLoc (), newIfOpDistResTypes,
1876+ newWarpOp. getResult (newIndices[ 0 ]), static_cast <bool >(ifOp.thenBlock ()),
18751877 static_cast <bool >(ifOp.elseBlock ()));
18761878 auto encloseRegionInWarpOp =
18771879 [&](Block *oldIfBranch, Block *newIfBranch,
@@ -1888,7 +1890,7 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
18881890 for (size_t i = 0 ; i < escapingValues.size ();
18891891 ++i, ++warpResRangeStart) {
18901892 innerWarpInputVals.push_back (
1891- newWarpOp.getResult (warpResRangeStart));
1893+ newWarpOp.getResult (newIndices[ warpResRangeStart] ));
18921894 escapeValToBlockArgIndex[escapingValues[i]] =
18931895 innerWarpInputTypes.size ();
18941896 innerWarpInputTypes.push_back (escapingValueInputTypes[i]);
@@ -1936,17 +1938,8 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
19361938 // Update the users of `<- WarpOp.yield <- IfOp.yield` to use the new `IfOp`
19371939 // result.
19381940 for (auto [origIdx, newIdx] : ifResultMapping)
1939- rewriter.replaceAllUsesExcept (warpOp .getResult (origIdx),
1941+ rewriter.replaceAllUsesExcept (newWarpOp .getResult (origIdx),
19401942 newIfOp.getResult (newIdx), newIfOp);
1941- // Similarly, update any users of the `WarpOp` results that were not
1942- // results of the `IfOp`.
1943- for (auto [origIdx, newIdx] : origToNewYieldIdx)
1944- rewriter.replaceAllUsesWith (warpOp.getResult (origIdx),
1945- newWarpOp.getResult (newIdx));
1946- // Remove the original `WarpOp` and `IfOp`, they should not have any uses
1947- // at this point.
1948- rewriter.eraseOp (ifOp);
1949- rewriter.eraseOp (warpOp);
19501943 return success ();
19511944 }
19521945
@@ -2065,19 +2058,16 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
20652058 escapingValueDistTypes.begin (),
20662059 escapingValueDistTypes.end ());
20672060 // Next, we insert all non-`ForOp` yielded values and their distributed
2068- // types. We also create a mapping between the non-`ForOp` yielded value
2069- // index and the corresponding new `WarpOp` yield value index (needed to
2070- // update users later).
2071- llvm::SmallDenseMap<unsigned , unsigned > nonForResultMapping;
2061+ // types.
20722062 for (auto [i, v] :
20732063 llvm::zip_equal (nonForResultIndices, nonForYieldedValues)) {
2074- nonForResultMapping[i] = newWarpOpYieldValues.size ();
20752064 newWarpOpYieldValues.push_back (v);
20762065 newWarpOpDistTypes.push_back (warpOp.getResult (i).getType ());
20772066 }
20782067 // Create the new `WarpOp` with the updated yield values and types.
2079- WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns (
2080- rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
2068+ SmallVector<size_t > newIndices;
2069+ WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns (
2070+ rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes, newIndices);
20812071
20822072 // Next, we create a new `ForOp` with the init args yielded by the new
20832073 // `WarpOp`.
@@ -2086,7 +2076,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
20862076 // escaping values in the new `WarpOp`.
20872077 SmallVector<Value> newForOpOperands;
20882078 for (size_t i = 0 ; i < escapingValuesStartIdx; ++i)
2089- newForOpOperands.push_back (newWarpOp.getResult (i ));
2079+ newForOpOperands.push_back (newWarpOp.getResult (newIndices[i] ));
20902080
20912081 // Create a new `ForOp` outside the new `WarpOp` region.
20922082 OpBuilder::InsertionGuard g (rewriter);
@@ -2110,7 +2100,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
21102100 llvm::SmallDenseMap<Value, int64_t > argIndexMapping;
21112101 for (size_t i = escapingValuesStartIdx;
21122102 i < escapingValuesStartIdx + escapingValues.size (); ++i) {
2113- innerWarpInput.push_back (newWarpOp.getResult (i ));
2103+ innerWarpInput.push_back (newWarpOp.getResult (newIndices[i] ));
21142104 argIndexMapping[escapingValues[i - escapingValuesStartIdx]] =
21152105 innerWarpInputType.size ();
21162106 innerWarpInputType.push_back (
@@ -2146,20 +2136,11 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
21462136 if (!innerWarp.getResults ().empty ())
21472137 scf::YieldOp::create (rewriter, forOp.getLoc (), innerWarp.getResults ());
21482138
2149- // Update the users of original `WarpOp` results that were coming from the
2139+ // Update the users of the new `WarpOp` results that were coming from the
21502140 // original `ForOp` to the corresponding new `ForOp` result.
21512141 for (auto [origIdx, newIdx] : forResultMapping)
2152- rewriter.replaceAllUsesExcept (warpOp .getResult (origIdx),
2142+ rewriter.replaceAllUsesExcept (newWarpOp .getResult (origIdx),
21532143 newForOp.getResult (newIdx), newForOp);
2154- // Similarly, update any users of the `WarpOp` results that were not
2155- // results of the `ForOp`.
2156- for (auto [origIdx, newIdx] : nonForResultMapping)
2157- rewriter.replaceAllUsesWith (warpOp.getResult (origIdx),
2158- newWarpOp.getResult (newIdx));
2159- // Remove the original `WarpOp` and `ForOp`, they should not have any uses
2160- // at this point.
2161- rewriter.eraseOp (forOp);
2162- rewriter.eraseOp (warpOp);
21632144 // Update any users of escaping values that were forwarded to the
21642145 // inner `WarpOp`. These values are now arguments of the inner `WarpOp`.
21652146 newForOp.walk ([&](Operation *op) {
0 commit comments