@@ -1751,13 +1751,13 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
17511751 PatternRewriter &rewriter) const override {
17521752 auto newWarpOpYield = cast<gpu::YieldOp>(
17531753 warpOp.getBodyRegion ().getBlocks ().begin ()->getTerminator ());
1754- // Only pick up forOp if it is the last op in the region.
1754+ // Only pick up `ForOp` if it is the last op in the region.
17551755 Operation *lastNode = newWarpOpYield->getPrevNode ();
17561756 auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
17571757 if (!forOp)
17581758 return failure ();
1759- // Collect Values that come from the warp op but are outside the forOp .
1760- // Those Value needs to be returned by the new warp op.
1759+ // Collect Values that come from the `WarpOp` but are outside the `ForOp` .
1760+ // Those Values need to be returned by the new warp op.
17611761 llvm::SmallSetVector<Value, 32 > escapingValues;
17621762 SmallVector<Type> escapingValueInputTypes;
17631763 SmallVector<Type> escapingValuedistTypes;
@@ -1779,16 +1779,16 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
17791779
17801780 if (llvm::is_contained (escapingValuedistTypes, Type{}))
17811781 return failure ();
1782- // Warp op can yield two types of values:
1783- // 1. Values that are not results of the forOp :
1784- // These values must also be yielded by the new warp op . Also, we need to
1785- // record the index mapping for these values to replace them later.
1786- // 2. Values that are results of the forOp :
1787- // In this case, we record the index mapping between the warp op result
1788- // index and matching forOp result index.
1782+ // `WarpOp` can yield two types of values:
1783+ // 1. Values that are not results of the `ForOp` :
1784+ // These values must also be yielded by the new `WarpOp` . Also, we need
1785+ // to record the index mapping for these values to replace them later.
1786+ // 2. Values that are results of the `ForOp` :
1787+ // In this case, we record the index mapping between the `WarpOp` result
1788+ // index and matching `ForOp` result index.
17891789 SmallVector<Value> nonForYieldedValues;
17901790 SmallVector<unsigned > nonForResultIndices;
1791- DenseMap <unsigned , unsigned > forResultMapping;
1791+ llvm::SmallDenseMap <unsigned , unsigned > forResultMapping;
17921792 for (OpOperand &yieldOperand : newWarpOpYield->getOpOperands ()) {
17931793 // Yielded value is not a result of the forOp.
17941794 if (yieldOperand.get ().getDefiningOp () != forOp.getOperation ()) {
@@ -1801,10 +1801,10 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
18011801 forResult.getResultNumber ();
18021802 }
18031803
1804- // Newly created warp op will yield values in following order:
1805- // 1. All init args of the forOp .
1804+ // Newly created `WarpOp` will yield values in following order:
1805+ // 1. All init args of the `ForOp` .
18061806 // 2. All escaping values.
1807- // 3. All non-for yielded values.
1807+ // 3. All non-`ForOp` yielded values.
18081808 SmallVector<Value> newWarpOpYieldValues;
18091809 SmallVector<Type> newWarpOpDistTypes;
18101810 for (auto [i, initArg] : llvm::enumerate (forOp.getInitArgs ())) {
@@ -1823,50 +1823,50 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
18231823 newWarpOpDistTypes.insert (newWarpOpDistTypes.end (),
18241824 escapingValuedistTypes.begin (),
18251825 escapingValuedistTypes.end ());
1826- // Next, we insert all non-for yielded values and their distributed types.
1827- // We also create a mapping between the non-for yielded value index and the
1828- // corresponding new warp op yield value index (needed to update users
1829- // later).
1830- DenseMap <unsigned , unsigned > warpResultMapping;
1826+ // Next, we insert all non-`ForOp` yielded values and their distributed
1827+ // types. We also create a mapping between the non-`ForOp` yielded value
1828+ // index and the corresponding new `WarpOp` yield value index (needed to
1829+ // update users later).
1830+ llvm::SmallDenseMap <unsigned , unsigned > warpResultMapping;
18311831 for (auto [i, v] : llvm::enumerate (nonForYieldedValues)) {
18321832 warpResultMapping[nonForResultIndices[i]] = newWarpOpYieldValues.size ();
18331833 newWarpOpYieldValues.push_back (v);
18341834 newWarpOpDistTypes.push_back (
18351835 warpOp.getResult (nonForResultIndices[i]).getType ());
18361836 }
1837- // Create the new warp op with the updated yield values and types.
1837+ // Create the new `WarpOp` with the updated yield values and types.
18381838 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns (
18391839 rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
18401840 newWarpOpYield = cast<gpu::YieldOp>(
18411841 newWarpOp.getBodyRegion ().getBlocks ().begin ()->getTerminator ());
18421842
1843- // Next, we create a new for op with the init args yielded by the new
1844- // warp op .
1845- unsigned escapingValuesStartIdx =
1846- forOp.getInitArgs ().size (); // ForOp init args are positioned before
1847- // escaping values in the new warp op .
1843+ // Next, we create a new `ForOp` with the init args yielded by the new
1844+ // `WarpOp` .
1845+ const unsigned escapingValuesStartIdx =
1846+ forOp.getInitArgs ().size (); // ` ForOp` init args are positioned before
1847+ // escaping values in the new `WarpOp` .
18481848 SmallVector<Value> newForOpOperands;
18491849 for (size_t i = 0 ; i < escapingValuesStartIdx; ++i)
18501850 newForOpOperands.push_back (newWarpOp.getResult (i));
18511851
1852- // Create a new for op outside the new warp op region.
1852+ // Create a new `ForOp` outside the new `WarpOp` region.
18531853 OpBuilder::InsertionGuard g (rewriter);
18541854 rewriter.setInsertionPointAfter (newWarpOp);
18551855 auto newForOp = rewriter.create <scf::ForOp>(
18561856 forOp.getLoc (), forOp.getLowerBound (), forOp.getUpperBound (),
18571857 forOp.getStep (), newForOpOperands);
1858- // Next, we insert a new warp op (called inner warp op ) inside the
1859- // newly created for op . This warp op will contain all ops that were
1860- // contained within the original for op body.
1858+ // Next, we insert a new `WarpOp` (called inner `WarpOp` ) inside the
1859+ // newly created `ForOp` . This `WarpOp` will contain all ops that were
1860+ // contained within the original `ForOp` body.
18611861 rewriter.setInsertionPointToStart (newForOp.getBody ());
18621862
18631863 SmallVector<Value> innerWarpInput (newForOp.getRegionIterArgs ().begin (),
18641864 newForOp.getRegionIterArgs ().end ());
18651865 SmallVector<Type> innerWarpInputType (forOp.getResultTypes ().begin (),
18661866 forOp.getResultTypes ().end ());
1867- // Escaping values are forwarded to the inner warp op as its (additional)
1867+ // Escaping values are forwarded to the inner `WarpOp` as its (additional)
18681868 // arguments. We keep track of the mapping between these values and their
1869- // argument index in the inner warp op (to replcace uses later).
1869+ // argument index in the inner `WarpOp` (to replace users later).
18701870 llvm::SmallDenseMap<Value, int64_t > argIndexMapping;
18711871 for (size_t i = escapingValuesStartIdx;
18721872 i < escapingValuesStartIdx + escapingValues.size (); ++i) {
@@ -1876,12 +1876,12 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
18761876 innerWarpInputType.push_back (
18771877 escapingValueInputTypes[i - escapingValuesStartIdx]);
18781878 }
1879- // Create the inner warp op with the new input values and types.
1879+ // Create the inner `WarpOp` with the new input values and types.
18801880 auto innerWarp = rewriter.create <WarpExecuteOnLane0Op>(
18811881 newWarpOp.getLoc (), newForOp.getResultTypes (), newWarpOp.getLaneid (),
18821882 newWarpOp.getWarpSize (), innerWarpInput, innerWarpInputType);
18831883
1884- // Inline the for op body into the inner warp op body.
1884+ // Inline the `ForOp` body into the inner `WarpOp` body.
18851885 SmallVector<Value> argMapping;
18861886 argMapping.push_back (newForOp.getInductionVar ());
18871887 for (Value args : innerWarp.getBody ()->getArguments ())
@@ -1895,32 +1895,32 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
18951895 rewriter.eraseOp (forOp.getBody ()->getTerminator ());
18961896 rewriter.mergeBlocks (forOp.getBody (), innerWarp.getBody (), argMapping);
18971897
1898- // Insert a gpu yieldOp at the end of the inner warp op body that yields
1899- // original forOp results.
1898+ // Insert a gpu `YieldOp` at the end of the inner `WarpOp` body that yields
1899+ // original `ForOp` results.
19001900 rewriter.setInsertionPointToEnd (innerWarp.getBody ());
19011901 rewriter.create <gpu::YieldOp>(innerWarp.getLoc (), yieldOperands);
19021902 rewriter.setInsertionPointAfter (innerWarp);
1903- // Insert a scf.yield op at the end of the new for op body that yields
1904- // the inner warp op results.
1903+ // Insert a scf.yield op at the end of the new `ForOp` body that yields
1904+ // the inner `WarpOp` results.
19051905 if (!innerWarp.getResults ().empty ())
19061906 rewriter.create <scf::YieldOp>(forOp.getLoc (), innerWarp.getResults ());
19071907
1908- // Update the users of original warp op results that were coming from the
1909- // original forOp to the corresponding new forOp result.
1908+ // Update the users of original `WarpOp` results that were coming from the
1909+ // original `ForOp` to the corresponding new `ForOp` result.
19101910 for (auto [origIdx, newIdx] : forResultMapping)
19111911 rewriter.replaceAllUsesExcept (warpOp.getResult (origIdx),
19121912 newForOp.getResult (newIdx), newForOp);
1913- // Similarly, update any users of the warp op results that were not
1914- // results of the forOp .
1913+ // Similarly, update any users of the `WarpOp` results that were not
1914+ // results of the `ForOp` .
19151915 for (auto [origIdx, newIdx] : warpResultMapping)
19161916 rewriter.replaceAllUsesWith (warpOp.getResult (origIdx),
19171917 newWarpOp.getResult (newIdx));
1918- // Remove the original warp op and for op , they should not have any uses
1918+ // Remove the original `WarpOp` and `ForOp` , they should not have any uses
19191919 // at this point.
19201920 rewriter.eraseOp (forOp);
19211921 rewriter.eraseOp (warpOp);
19221922 // Update any users of escaping values that were forwarded to the
1923- // inner warp op . These values are now arguments of the inner warp op .
1923+ // inner `WarpOp` . These values are now arguments of the inner `WarpOp` .
19241924 newForOp.walk ([&](Operation *op) {
19251925 for (OpOperand &operand : op->getOpOperands ()) {
19261926 auto it = argIndexMapping.find (operand.get ());
@@ -1930,7 +1930,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
19301930 }
19311931 });
19321932
1933- // Finally, hoist out any now uniform code from the inner warp op .
1933+ // Finally, hoist out any now uniform code from the inner `WarpOp` .
19341934 mlir::vector::moveScalarUniformCode (innerWarp);
19351935 return success ();
19361936 }
0 commit comments