@@ -1796,26 +1796,34 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
17961796 yieldedValuesFromWarpOp.insert (yieldedValuesFromWarpOp.end (),
17971797 escapingValues.begin (),
17981798 escapingValues.end ());
1799-
1800- SmallVector<size_t > newRetIndices;
1801- WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns (
1802- rewriter, warpOp, yieldedValuesFromWarpOp, distTypes, newRetIndices);
1803- yield = cast<gpu::YieldOp>(
1804- newWarpOp.getBodyRegion ().getBlocks ().begin ()->getTerminator ());
1805-
1806- SmallVector<Value> newOperands;
1799+ // record result mapping.
18071800 SmallVector<unsigned > resultIdx;
1808- // Collect the new init args coming from the new warp op.
1809- for (size_t i = 0 ; i < forOp.getInitArgs ().size (); ++i)
1810- newOperands.push_back (newWarpOp.getResult (newRetIndices[i]));
18111801 for (OpOperand &yieldOperand : yield->getOpOperands ()) {
18121802 if (yieldOperand.get ().getDefiningOp () != forOp.getOperation ())
18131803 continue ;
18141804 OpResult forResult = cast<OpResult>(yieldOperand.get ());
18151805 resultIdx.push_back (forResult.getResultNumber ());
1816- yieldOperand.set (forOp.getInitArgs ()[forResult.getResultNumber ()]);
1806+ // yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
18171807 }
18181808
1809+ // SmallVector<size_t> newRetIndices;
1810+ WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns (
1811+ rewriter, warpOp, yieldedValuesFromWarpOp, distTypes);
1812+ yield = cast<gpu::YieldOp>(
1813+ newWarpOp.getBodyRegion ().getBlocks ().begin ()->getTerminator ());
1814+
1815+ SmallVector<Value> newOperands;
1816+ // Collect the new init args coming from the new warp op.
1817+ for (size_t i = 0 ; i < forOp.getInitArgs ().size (); ++i)
1818+ newOperands.push_back (newWarpOp.getResult (i));
1819+ // for (OpOperand &yieldOperand : yield->getOpOperands()) {
1820+ // if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
1821+ // continue;
1822+ // OpResult forResult = cast<OpResult>(yieldOperand.get());
1823+ // resultIdx.push_back(forResult.getResultNumber());
1824+ // yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
1825+ // }
1826+
18191827 OpBuilder::InsertionGuard g (rewriter);
18201828 rewriter.setInsertionPointAfter (newWarpOp);
18211829
@@ -1831,7 +1839,8 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
18311839 SmallVector<Type> warpInputType (forOp.getResultTypes ().begin (),
18321840 forOp.getResultTypes ().end ());
18331841 llvm::SmallDenseMap<Value, int64_t > argIndexMapping;
1834- for (size_t i = forOp.getInitArgs ().size (); i < newRetIndices.size (); ++i) {
1842+ for (size_t i = forOp.getInitArgs ().size (); i < newWarpOp->getNumResults ();
1843+ ++i) {
18351844 warpInput.push_back (newWarpOp.getResult (i));
18361845 argIndexMapping[escapingValues[i]] = warpInputType.size ();
18371846 warpInputType.push_back (inputTypes[i]);
@@ -1870,12 +1879,13 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
18701879 llvm::errs () << idx << " " ;
18711880 llvm::errs () << " \n " ;
18721881 for (const auto &res : llvm::enumerate (resultIdx)) {
1873- rewriter.replaceAllUsesWith (newWarpOp .getResult (res.value ()),
1874- newForOp.getResult (res.index ()));
1882+ rewriter.replaceAllUsesExcept (warpOp .getResult (res.value ()),
1883+ newForOp.getResult (res.index ()), newForOp );
18751884 // newForOp->setOperand(res.index() + 3,
18761885 // newWarpOp.getResult(res.value()));
18771886 }
18781887 rewriter.eraseOp (forOp);
1888+ rewriter.eraseOp (warpOp);
18791889 newForOp.walk ([&](Operation *op) {
18801890 for (OpOperand &operand : op->getOpOperands ()) {
18811891 auto it = argIndexMapping.find (operand.get ());
0 commit comments