@@ -1704,19 +1704,18 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
17041704 : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
17051705 LogicalResult matchAndRewrite (WarpExecuteOnLane0Op warpOp,
17061706 PatternRewriter &rewriter) const override {
1707- auto yield = cast<gpu::YieldOp>(
1707+ auto warpOpYield = cast<gpu::YieldOp>(
17081708 warpOp.getBodyRegion ().getBlocks ().begin ()->getTerminator ());
1709- // Only pick up forOp if it is the last op in the region.
1710- Operation *lastNode = yield ->getPrevNode ();
1709+ // Only pick up `ForOp` if it is the last op in the region.
1710+ Operation *lastNode = warpOpYield ->getPrevNode ();
17111711 auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
17121712 if (!forOp)
17131713 return failure ();
1714- // Collect Values that come from the warp op but are outside the forOp.
1715- // Those Value needs to be returned by the original warpOp and passed to
1716- // the new op.
1714+ // Collect Values that come from the `WarpOp` but are outside the `ForOp`.
1715+ // Those Values need to be returned by the new warp op.
17171716 llvm::SmallSetVector<Value, 32 > escapingValues;
1718- SmallVector<Type> inputTypes ;
1719- SmallVector<Type> distTypes ;
1717+ SmallVector<Type> escapingValueInputTypes ;
1718+ SmallVector<Type> escapingValueDistTypes ;
17201719 mlir::visitUsedValuesDefinedAbove (
17211720 forOp.getBodyRegion (), [&](OpOperand *operand) {
17221721 Operation *parent = operand->get ().getParentRegion ()->getParentOp ();
@@ -1728,81 +1727,168 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
17281727 AffineMap map = distributionMapFn (operand->get ());
17291728 distType = getDistributedType (vecType, map, warpOp.getWarpSize ());
17301729 }
1731- inputTypes .push_back (operand->get ().getType ());
1732- distTypes .push_back (distType);
1730+ escapingValueInputTypes .push_back (operand->get ().getType ());
1731+ escapingValueDistTypes .push_back (distType);
17331732 }
17341733 });
17351734
1736- if (llvm::is_contained (distTypes , Type{}))
1735+ if (llvm::is_contained (escapingValueDistTypes , Type{}))
17371736 return failure ();
1738-
1739- SmallVector<size_t > newRetIndices;
1740- WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns (
1741- rewriter, warpOp, escapingValues.getArrayRef (), distTypes,
1742- newRetIndices);
1743- yield = cast<gpu::YieldOp>(
1744- newWarpOp.getBodyRegion ().getBlocks ().begin ()->getTerminator ());
1745-
1746- SmallVector<Value> newOperands;
1747- SmallVector<unsigned > resultIdx;
1748- // Collect all the outputs coming from the forOp.
1749- for (OpOperand &yieldOperand : yield->getOpOperands ()) {
1750- if (yieldOperand.get ().getDefiningOp () != forOp.getOperation ())
1737+ // `WarpOp` can yield two types of values:
1738+ // 1. Values that are not results of the `ForOp`:
1739+ // These values must also be yielded by the new `WarpOp`. Also, we need
1740+ // to record the index mapping for these values to replace them later.
1741+ // 2. Values that are results of the `ForOp`:
1742+ // In this case, we record the index mapping between the `WarpOp` result
1743+ // index and matching `ForOp` result index.
1744+ // Additionally, we keep track of the distributed types for all `ForOp`
1745+ // vector results.
1746+ SmallVector<Value> nonForYieldedValues;
1747+ SmallVector<unsigned > nonForResultIndices;
1748+ llvm::SmallDenseMap<unsigned , unsigned > forResultMapping;
1749+ llvm::SmallDenseMap<unsigned , VectorType> forResultDistTypes;
1750+ for (OpOperand &yieldOperand : warpOpYield->getOpOperands ()) {
1751+ // Yielded value is not a result of the forOp.
1752+ if (yieldOperand.get ().getDefiningOp () != forOp.getOperation ()) {
1753+ nonForYieldedValues.push_back (yieldOperand.get ());
1754+ nonForResultIndices.push_back (yieldOperand.getOperandNumber ());
17511755 continue ;
1752- auto forResult = cast<OpResult>(yieldOperand.get ());
1753- newOperands.push_back (
1754- newWarpOp.getResult (yieldOperand.getOperandNumber ()));
1755- yieldOperand.set (forOp.getInitArgs ()[forResult.getResultNumber ()]);
1756- resultIdx.push_back (yieldOperand.getOperandNumber ());
1756+ }
1757+ OpResult forResult = cast<OpResult>(yieldOperand.get ());
1758+ unsigned int forResultNumber = forResult.getResultNumber ();
1759+ forResultMapping[yieldOperand.getOperandNumber ()] = forResultNumber;
1760+ // If this `ForOp` result is vector type and it is yielded by the
1761+ // `WarpOp`, we keep track the distributed type for this result.
1762+ if (!isa<VectorType>(forResult.getType ()))
1763+ continue ;
1764+ VectorType distType = cast<VectorType>(
1765+ warpOp.getResult (yieldOperand.getOperandNumber ()).getType ());
1766+ forResultDistTypes[forResultNumber] = distType;
17571767 }
17581768
1769+ // Newly created `WarpOp` will yield values in following order:
1770+ // 1. All init args of the `ForOp`.
1771+ // 2. All escaping values.
1772+ // 3. All non-`ForOp` yielded values.
1773+ SmallVector<Value> newWarpOpYieldValues;
1774+ SmallVector<Type> newWarpOpDistTypes;
1775+ for (auto [i, initArg] : llvm::enumerate (forOp.getInitArgs ())) {
1776+ newWarpOpYieldValues.push_back (initArg);
1777+ // Compute the distributed type for this init arg.
1778+ Type distType = initArg.getType ();
1779+ if (auto vecType = dyn_cast<VectorType>(distType)) {
1780+ // If the `ForOp` result corresponds to this init arg is already yielded
1781+ // we can get the distributed type from `forResultDistTypes` map.
1782+ // Otherwise, we compute it using distributionMapFn.
1783+ AffineMap map = distributionMapFn (initArg);
1784+ distType = forResultDistTypes.count (i)
1785+ ? forResultDistTypes[i]
1786+ : getDistributedType (vecType, map, warpOp.getWarpSize ());
1787+ }
1788+ newWarpOpDistTypes.push_back (distType);
1789+ }
1790+ // Insert escaping values and their distributed types.
1791+ newWarpOpYieldValues.insert (newWarpOpYieldValues.end (),
1792+ escapingValues.begin (), escapingValues.end ());
1793+ newWarpOpDistTypes.insert (newWarpOpDistTypes.end (),
1794+ escapingValueDistTypes.begin (),
1795+ escapingValueDistTypes.end ());
1796+ // Next, we insert all non-`ForOp` yielded values and their distributed
1797+ // types. We also create a mapping between the non-`ForOp` yielded value
1798+ // index and the corresponding new `WarpOp` yield value index (needed to
1799+ // update users later).
1800+ llvm::SmallDenseMap<unsigned , unsigned > nonForResultMapping;
1801+ for (auto [i, v] :
1802+ llvm::zip_equal (nonForResultIndices, nonForYieldedValues)) {
1803+ nonForResultMapping[i] = newWarpOpYieldValues.size ();
1804+ newWarpOpYieldValues.push_back (v);
1805+ newWarpOpDistTypes.push_back (warpOp.getResult (i).getType ());
1806+ }
1807+ // Create the new `WarpOp` with the updated yield values and types.
1808+ WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns (
1809+ rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
1810+
1811+ // Next, we create a new `ForOp` with the init args yielded by the new
1812+ // `WarpOp`.
1813+ const unsigned escapingValuesStartIdx =
1814+ forOp.getInitArgs ().size (); // `ForOp` init args are positioned before
1815+ // escaping values in the new `WarpOp`.
1816+ SmallVector<Value> newForOpOperands;
1817+ for (size_t i = 0 ; i < escapingValuesStartIdx; ++i)
1818+ newForOpOperands.push_back (newWarpOp.getResult (i));
1819+
1820+ // Create a new `ForOp` outside the new `WarpOp` region.
17591821 OpBuilder::InsertionGuard g (rewriter);
17601822 rewriter.setInsertionPointAfter (newWarpOp);
1761-
1762- // Create a new for op outside the region with a WarpExecuteOnLane0Op
1763- // region inside.
17641823 auto newForOp = rewriter.create <scf::ForOp>(
17651824 forOp.getLoc (), forOp.getLowerBound (), forOp.getUpperBound (),
1766- forOp.getStep (), newOperands);
1825+ forOp.getStep (), newForOpOperands);
1826+ // Next, we insert a new `WarpOp` (called inner `WarpOp`) inside the
1827+ // newly created `ForOp`. This `WarpOp` will contain all ops that were
1828+ // contained within the original `ForOp` body.
17671829 rewriter.setInsertionPointToStart (newForOp.getBody ());
17681830
1769- SmallVector<Value> warpInput (newForOp.getRegionIterArgs ().begin (),
1770- newForOp.getRegionIterArgs ().end ());
1771- SmallVector<Type> warpInputType (forOp.getResultTypes ().begin (),
1772- forOp.getResultTypes ().end ());
1831+ SmallVector<Value> innerWarpInput (newForOp.getRegionIterArgs ().begin (),
1832+ newForOp.getRegionIterArgs ().end ());
1833+ SmallVector<Type> innerWarpInputType (forOp.getResultTypes ().begin (),
1834+ forOp.getResultTypes ().end ());
1835+ // Escaping values are forwarded to the inner `WarpOp` as its (additional)
1836+ // arguments. We keep track of the mapping between these values and their
1837+ // argument index in the inner `WarpOp` (to replace users later).
17731838 llvm::SmallDenseMap<Value, int64_t > argIndexMapping;
1774- for (auto [i, retIdx] : llvm::enumerate (newRetIndices)) {
1775- warpInput.push_back (newWarpOp.getResult (retIdx));
1776- argIndexMapping[escapingValues[i]] = warpInputType.size ();
1777- warpInputType.push_back (inputTypes[i]);
1839+ for (size_t i = escapingValuesStartIdx;
1840+ i < escapingValuesStartIdx + escapingValues.size (); ++i) {
1841+ innerWarpInput.push_back (newWarpOp.getResult (i));
1842+ argIndexMapping[escapingValues[i - escapingValuesStartIdx]] =
1843+ innerWarpInputType.size ();
1844+ innerWarpInputType.push_back (
1845+ escapingValueInputTypes[i - escapingValuesStartIdx]);
17781846 }
1847+ // Create the inner `WarpOp` with the new input values and types.
17791848 auto innerWarp = rewriter.create <WarpExecuteOnLane0Op>(
17801849 newWarpOp.getLoc (), newForOp.getResultTypes (), newWarpOp.getLaneid (),
1781- newWarpOp.getWarpSize (), warpInput, warpInputType );
1850+ newWarpOp.getWarpSize (), innerWarpInput, innerWarpInputType );
17821851
1852+ // Inline the `ForOp` body into the inner `WarpOp` body.
17831853 SmallVector<Value> argMapping;
17841854 argMapping.push_back (newForOp.getInductionVar ());
1785- for (Value args : innerWarp.getBody ()->getArguments ()) {
1855+ for (Value args : innerWarp.getBody ()->getArguments ())
17861856 argMapping.push_back (args);
1787- }
1857+
17881858 argMapping.resize (forOp.getBody ()->getNumArguments ());
17891859 SmallVector<Value> yieldOperands;
17901860 for (Value operand : forOp.getBody ()->getTerminator ()->getOperands ())
17911861 yieldOperands.push_back (operand);
1862+
17921863 rewriter.eraseOp (forOp.getBody ()->getTerminator ());
17931864 rewriter.mergeBlocks (forOp.getBody (), innerWarp.getBody (), argMapping);
1865+
1866+ // Insert a gpu `YieldOp` at the end of the inner `WarpOp` body that yields
1867+ // original `ForOp` results.
17941868 rewriter.setInsertionPointToEnd (innerWarp.getBody ());
17951869 rewriter.create <gpu::YieldOp>(innerWarp.getLoc (), yieldOperands);
17961870 rewriter.setInsertionPointAfter (innerWarp);
1871+ // Insert a scf.yield op at the end of the new `ForOp` body that yields
1872+ // the inner `WarpOp` results.
17971873 if (!innerWarp.getResults ().empty ())
17981874 rewriter.create <scf::YieldOp>(forOp.getLoc (), innerWarp.getResults ());
1875+
1876+ // Update the users of original `WarpOp` results that were coming from the
1877+ // original `ForOp` to the corresponding new `ForOp` result.
1878+ for (auto [origIdx, newIdx] : forResultMapping)
1879+ rewriter.replaceAllUsesExcept (warpOp.getResult (origIdx),
1880+ newForOp.getResult (newIdx), newForOp);
1881+ // Similarly, update any users of the `WarpOp` results that were not
1882+ // results of the `ForOp`.
1883+ for (auto [origIdx, newIdx] : nonForResultMapping)
1884+ rewriter.replaceAllUsesWith (warpOp.getResult (origIdx),
1885+ newWarpOp.getResult (newIdx));
1886+ // Remove the original `WarpOp` and `ForOp`, they should not have any uses
1887+ // at this point.
17991888 rewriter.eraseOp (forOp);
1800- // Replace the warpOp result coming from the original ForOp.
1801- for (const auto &res : llvm::enumerate (resultIdx)) {
1802- rewriter.replaceAllUsesWith (newWarpOp.getResult (res.value ()),
1803- newForOp.getResult (res.index ()));
1804- newForOp->setOperand (res.index () + 3 , newWarpOp.getResult (res.value ()));
1805- }
1889+ rewriter.eraseOp (warpOp);
1890+ // Update any users of escaping values that were forwarded to the
1891+ // inner `WarpOp`. These values are now arguments of the inner `WarpOp`.
18061892 newForOp.walk ([&](Operation *op) {
18071893 for (OpOperand &operand : op->getOpOperands ()) {
18081894 auto it = argIndexMapping.find (operand.get ());
@@ -1812,7 +1898,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
18121898 }
18131899 });
18141900
1815- // Finally, hoist out any now uniform code from the inner warp op .
1901+ // Finally, hoist out any now uniform code from the inner `WarpOp` .
18161902 mlir::vector::moveScalarUniformCode (innerWarp);
18171903 return success ();
18181904 }
0 commit comments