@@ -1704,19 +1704,18 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
17041704      : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
17051705  LogicalResult matchAndRewrite (WarpExecuteOnLane0Op warpOp,
17061706                                PatternRewriter &rewriter) const  override  {
1707-     auto  yield  = cast<gpu::YieldOp>(
1707+     auto  warpOpYield  = cast<gpu::YieldOp>(
17081708        warpOp.getBodyRegion ().getBlocks ().begin ()->getTerminator ());
1709-     //  Only pick up forOp  if it is the last op in the region.
1710-     Operation *lastNode = yield ->getPrevNode ();
1709+     //  Only pick up `ForOp`  if it is the last op in the region.
1710+     Operation *lastNode = warpOpYield ->getPrevNode ();
17111711    auto  forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
17121712    if  (!forOp)
17131713      return  failure ();
1714-     //  Collect Values that come from the warp op but are outside the forOp.
1715-     //  Those Value needs to be returned by the original warpOp and passed to
1716-     //  the new op.
1714+     //  Collect Values that come from the `WarpOp` but are outside the `ForOp`.
1715+     //  Those Values need to be returned by the new warp op.
17171716    llvm::SmallSetVector<Value, 32 > escapingValues;
1718-     SmallVector<Type> inputTypes ;
1719-     SmallVector<Type> distTypes ;
1717+     SmallVector<Type> escapingValueInputTypes ;
1718+     SmallVector<Type> escapingValueDistTypes ;
17201719    mlir::visitUsedValuesDefinedAbove (
17211720        forOp.getBodyRegion (), [&](OpOperand *operand) {
17221721          Operation *parent = operand->get ().getParentRegion ()->getParentOp ();
@@ -1728,81 +1727,168 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
17281727              AffineMap map = distributionMapFn (operand->get ());
17291728              distType = getDistributedType (vecType, map, warpOp.getWarpSize ());
17301729            }
1731-             inputTypes .push_back (operand->get ().getType ());
1732-             distTypes .push_back (distType);
1730+             escapingValueInputTypes .push_back (operand->get ().getType ());
1731+             escapingValueDistTypes .push_back (distType);
17331732          }
17341733        });
17351734
1736-     if  (llvm::is_contained (distTypes , Type{}))
1735+     if  (llvm::is_contained (escapingValueDistTypes , Type{}))
17371736      return  failure ();
1738- 
1739-     SmallVector<size_t > newRetIndices;
1740-     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns (
1741-         rewriter, warpOp, escapingValues.getArrayRef (), distTypes,
1742-         newRetIndices);
1743-     yield = cast<gpu::YieldOp>(
1744-         newWarpOp.getBodyRegion ().getBlocks ().begin ()->getTerminator ());
1745- 
1746-     SmallVector<Value> newOperands;
1747-     SmallVector<unsigned > resultIdx;
1748-     //  Collect all the outputs coming from the forOp.
1749-     for  (OpOperand &yieldOperand : yield->getOpOperands ()) {
1750-       if  (yieldOperand.get ().getDefiningOp () != forOp.getOperation ())
1737+     //  `WarpOp` can yield two types of values:
1738+     //  1. Values that are not results of the `ForOp`:
1739+     //     These values must also be yielded by the new `WarpOp`. Also, we need
1740+     //     to record the index mapping for these values to replace them later.
1741+     //  2. Values that are results of the `ForOp`:
1742+     //     In this case, we record the index mapping between the `WarpOp` result
1743+     //     index and matching `ForOp` result index.
1744+     //  Additionally, we keep track of the distributed types for all `ForOp`
1745+     //  vector results.
1746+     SmallVector<Value> nonForYieldedValues;
1747+     SmallVector<unsigned > nonForResultIndices;
1748+     llvm::SmallDenseMap<unsigned , unsigned > forResultMapping;
1749+     llvm::SmallDenseMap<unsigned , VectorType> forResultDistTypes;
1750+     for  (OpOperand &yieldOperand : warpOpYield->getOpOperands ()) {
1751+       //  Yielded value is not a result of the forOp.
1752+       if  (yieldOperand.get ().getDefiningOp () != forOp.getOperation ()) {
1753+         nonForYieldedValues.push_back (yieldOperand.get ());
1754+         nonForResultIndices.push_back (yieldOperand.getOperandNumber ());
17511755        continue ;
1752-       auto  forResult = cast<OpResult>(yieldOperand.get ());
1753-       newOperands.push_back (
1754-           newWarpOp.getResult (yieldOperand.getOperandNumber ()));
1755-       yieldOperand.set (forOp.getInitArgs ()[forResult.getResultNumber ()]);
1756-       resultIdx.push_back (yieldOperand.getOperandNumber ());
1756+       }
1757+       OpResult forResult = cast<OpResult>(yieldOperand.get ());
1758+       unsigned  int  forResultNumber = forResult.getResultNumber ();
1759+       forResultMapping[yieldOperand.getOperandNumber ()] = forResultNumber;
1760+       //  If this `ForOp` result is vector type and it is yielded by the
1761+       //  `WarpOp`, we keep track the distributed type for this result.
1762+       if  (!isa<VectorType>(forResult.getType ()))
1763+         continue ;
1764+       VectorType distType = cast<VectorType>(
1765+           warpOp.getResult (yieldOperand.getOperandNumber ()).getType ());
1766+       forResultDistTypes[forResultNumber] = distType;
17571767    }
17581768
1769+     //  Newly created `WarpOp` will yield values in following order:
1770+     //  1. All init args of the `ForOp`.
1771+     //  2. All escaping values.
1772+     //  3. All non-`ForOp` yielded values.
1773+     SmallVector<Value> newWarpOpYieldValues;
1774+     SmallVector<Type> newWarpOpDistTypes;
1775+     for  (auto  [i, initArg] : llvm::enumerate (forOp.getInitArgs ())) {
1776+       newWarpOpYieldValues.push_back (initArg);
1777+       //  Compute the distributed type for this init arg.
1778+       Type distType = initArg.getType ();
1779+       if  (auto  vecType = dyn_cast<VectorType>(distType)) {
1780+         //  If the `ForOp` result corresponds to this init arg is already yielded
1781+         //  we can get the distributed type from `forResultDistTypes` map.
1782+         //  Otherwise, we compute it using distributionMapFn.
1783+         AffineMap map = distributionMapFn (initArg);
1784+         distType = forResultDistTypes.count (i)
1785+                        ? forResultDistTypes[i]
1786+                        : getDistributedType (vecType, map, warpOp.getWarpSize ());
1787+       }
1788+       newWarpOpDistTypes.push_back (distType);
1789+     }
1790+     //  Insert escaping values and their distributed types.
1791+     newWarpOpYieldValues.insert (newWarpOpYieldValues.end (),
1792+                                 escapingValues.begin (), escapingValues.end ());
1793+     newWarpOpDistTypes.insert (newWarpOpDistTypes.end (),
1794+                               escapingValueDistTypes.begin (),
1795+                               escapingValueDistTypes.end ());
1796+     //  Next, we insert all non-`ForOp` yielded values and their distributed
1797+     //  types. We also create a mapping between the non-`ForOp` yielded value
1798+     //  index and the corresponding new `WarpOp` yield value index (needed to
1799+     //  update users later).
1800+     llvm::SmallDenseMap<unsigned , unsigned > nonForResultMapping;
1801+     for  (auto  [i, v] :
1802+          llvm::zip_equal (nonForResultIndices, nonForYieldedValues)) {
1803+       nonForResultMapping[i] = newWarpOpYieldValues.size ();
1804+       newWarpOpYieldValues.push_back (v);
1805+       newWarpOpDistTypes.push_back (warpOp.getResult (i).getType ());
1806+     }
1807+     //  Create the new `WarpOp` with the updated yield values and types.
1808+     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns (
1809+         rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
1810+ 
1811+     //  Next, we create a new `ForOp` with the init args yielded by the new
1812+     //  `WarpOp`.
1813+     const  unsigned  escapingValuesStartIdx =
1814+         forOp.getInitArgs ().size (); //  `ForOp` init args are positioned before
1815+                                     //  escaping values in the new `WarpOp`.
1816+     SmallVector<Value> newForOpOperands;
1817+     for  (size_t  i = 0 ; i < escapingValuesStartIdx; ++i)
1818+       newForOpOperands.push_back (newWarpOp.getResult (i));
1819+ 
1820+     //  Create a new `ForOp` outside the new `WarpOp` region.
17591821    OpBuilder::InsertionGuard g (rewriter);
17601822    rewriter.setInsertionPointAfter (newWarpOp);
1761- 
1762-     //  Create a new for op outside the region with a WarpExecuteOnLane0Op
1763-     //  region inside.
17641823    auto  newForOp = rewriter.create <scf::ForOp>(
17651824        forOp.getLoc (), forOp.getLowerBound (), forOp.getUpperBound (),
1766-         forOp.getStep (), newOperands);
1825+         forOp.getStep (), newForOpOperands);
1826+     //  Next, we insert a new `WarpOp` (called inner `WarpOp`) inside the
1827+     //  newly created `ForOp`. This `WarpOp` will contain all ops that were
1828+     //  contained within the original `ForOp` body.
17671829    rewriter.setInsertionPointToStart (newForOp.getBody ());
17681830
1769-     SmallVector<Value> warpInput (newForOp.getRegionIterArgs ().begin (),
1770-                                  newForOp.getRegionIterArgs ().end ());
1771-     SmallVector<Type> warpInputType (forOp.getResultTypes ().begin (),
1772-                                     forOp.getResultTypes ().end ());
1831+     SmallVector<Value> innerWarpInput (newForOp.getRegionIterArgs ().begin (),
1832+                                       newForOp.getRegionIterArgs ().end ());
1833+     SmallVector<Type> innerWarpInputType (forOp.getResultTypes ().begin (),
1834+                                          forOp.getResultTypes ().end ());
1835+     //  Escaping values are forwarded to the inner `WarpOp` as its (additional)
1836+     //  arguments. We keep track of the mapping between these values and their
1837+     //  argument index in the inner `WarpOp` (to replace users later).
17731838    llvm::SmallDenseMap<Value, int64_t > argIndexMapping;
1774-     for  (auto  [i, retIdx] : llvm::enumerate (newRetIndices)) {
1775-       warpInput.push_back (newWarpOp.getResult (retIdx));
1776-       argIndexMapping[escapingValues[i]] = warpInputType.size ();
1777-       warpInputType.push_back (inputTypes[i]);
1839+     for  (size_t  i = escapingValuesStartIdx;
1840+          i < escapingValuesStartIdx + escapingValues.size (); ++i) {
1841+       innerWarpInput.push_back (newWarpOp.getResult (i));
1842+       argIndexMapping[escapingValues[i - escapingValuesStartIdx]] =
1843+           innerWarpInputType.size ();
1844+       innerWarpInputType.push_back (
1845+           escapingValueInputTypes[i - escapingValuesStartIdx]);
17781846    }
1847+     //  Create the inner `WarpOp` with the new input values and types.
17791848    auto  innerWarp = rewriter.create <WarpExecuteOnLane0Op>(
17801849        newWarpOp.getLoc (), newForOp.getResultTypes (), newWarpOp.getLaneid (),
1781-         newWarpOp.getWarpSize (), warpInput, warpInputType );
1850+         newWarpOp.getWarpSize (), innerWarpInput, innerWarpInputType );
17821851
1852+     //  Inline the `ForOp` body into the inner `WarpOp` body.
17831853    SmallVector<Value> argMapping;
17841854    argMapping.push_back (newForOp.getInductionVar ());
1785-     for  (Value args : innerWarp.getBody ()->getArguments ()) { 
1855+     for  (Value args : innerWarp.getBody ()->getArguments ())
17861856      argMapping.push_back (args);
1787-     } 
1857+ 
17881858    argMapping.resize (forOp.getBody ()->getNumArguments ());
17891859    SmallVector<Value> yieldOperands;
17901860    for  (Value operand : forOp.getBody ()->getTerminator ()->getOperands ())
17911861      yieldOperands.push_back (operand);
1862+ 
17921863    rewriter.eraseOp (forOp.getBody ()->getTerminator ());
17931864    rewriter.mergeBlocks (forOp.getBody (), innerWarp.getBody (), argMapping);
1865+ 
1866+     //  Insert a gpu `YieldOp` at the end of the inner `WarpOp` body that yields
1867+     //  original `ForOp` results.
17941868    rewriter.setInsertionPointToEnd (innerWarp.getBody ());
17951869    rewriter.create <gpu::YieldOp>(innerWarp.getLoc (), yieldOperands);
17961870    rewriter.setInsertionPointAfter (innerWarp);
1871+     //  Insert a scf.yield op at the end of the new `ForOp` body that yields
1872+     //  the inner `WarpOp` results.
17971873    if  (!innerWarp.getResults ().empty ())
17981874      rewriter.create <scf::YieldOp>(forOp.getLoc (), innerWarp.getResults ());
1875+ 
1876+     //  Update the users of original `WarpOp` results that were coming from the
1877+     //  original `ForOp` to the corresponding new `ForOp` result.
1878+     for  (auto  [origIdx, newIdx] : forResultMapping)
1879+       rewriter.replaceAllUsesExcept (warpOp.getResult (origIdx),
1880+                                     newForOp.getResult (newIdx), newForOp);
1881+     //  Similarly, update any users of the `WarpOp` results that were not
1882+     //  results of the `ForOp`.
1883+     for  (auto  [origIdx, newIdx] : nonForResultMapping)
1884+       rewriter.replaceAllUsesWith (warpOp.getResult (origIdx),
1885+                                   newWarpOp.getResult (newIdx));
1886+     //  Remove the original `WarpOp` and `ForOp`, they should not have any uses
1887+     //  at this point.
17991888    rewriter.eraseOp (forOp);
1800-     //  Replace the warpOp result coming from the original ForOp.
1801-     for  (const  auto  &res : llvm::enumerate (resultIdx)) {
1802-       rewriter.replaceAllUsesWith (newWarpOp.getResult (res.value ()),
1803-                                   newForOp.getResult (res.index ()));
1804-       newForOp->setOperand (res.index () + 3 , newWarpOp.getResult (res.value ()));
1805-     }
1889+     rewriter.eraseOp (warpOp);
1890+     //  Update any users of escaping values that were forwarded to the
1891+     //  inner `WarpOp`. These values are now arguments of the inner `WarpOp`.
18061892    newForOp.walk ([&](Operation *op) {
18071893      for  (OpOperand &operand : op->getOpOperands ()) {
18081894        auto  it = argIndexMapping.find (operand.get ());
@@ -1812,7 +1898,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
18121898      }
18131899    });
18141900
1815-     //  Finally, hoist out any now uniform code from the inner warp op .
1901+     //  Finally, hoist out any now uniform code from the inner `WarpOp` .
18161902    mlir::vector::moveScalarUniformCode (innerWarp);
18171903    return  success ();
18181904  }
0 commit comments