@@ -1713,6 +1713,205 @@ struct WarpOpInsert : public WarpDistributionPattern {
17131713 }
17141714};
17151715
1716+ struct WarpOpScfIfOp : public WarpDistributionPattern {
1717+ WarpOpScfIfOp (MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1 )
1718+ : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
1719+ LogicalResult matchAndRewrite (WarpExecuteOnLane0Op warpOp,
1720+ PatternRewriter &rewriter) const override {
1721+ gpu::YieldOp warpOpYield = warpOp.getTerminator ();
1722+ // Only pick up `IfOp` if it is the last op in the region.
1723+ Operation *lastNode = warpOpYield->getPrevNode ();
1724+ auto ifOp = dyn_cast_or_null<scf::IfOp>(lastNode);
1725+ if (!ifOp)
1726+ return failure ();
1727+
1728+ // The current `WarpOp` can yield two types of values:
1729+ // 1. Not results of `IfOp`:
1730+ // Preserve them in the new `WarpOp`.
1731+ // Collect their yield index.
1732+ // 2. Results of `IfOp`:
1733+ // They are not part of the new `WarpOp` results.
1734+ // Map current warp's yield operand index to `IfOp` result idx.
1735+ SmallVector<Value> nonIfYieldValues;
1736+ SmallVector<unsigned > nonIfYieldIndices;
1737+ llvm::SmallDenseMap<unsigned , unsigned > ifResultMapping;
1738+ llvm::SmallDenseMap<unsigned , VectorType> ifResultDistTypes;
1739+ for (OpOperand &yieldOperand : warpOpYield->getOpOperands ()) {
1740+ const unsigned yieldOperandIdx = yieldOperand.getOperandNumber ();
1741+ if (yieldOperand.get ().getDefiningOp () != ifOp.getOperation ()) {
1742+ nonIfYieldValues.push_back (yieldOperand.get ());
1743+ nonIfYieldIndices.push_back (yieldOperandIdx);
1744+ continue ;
1745+ }
1746+ OpResult ifResult = cast<OpResult>(yieldOperand.get ());
1747+ const unsigned ifResultIdx = ifResult.getResultNumber ();
1748+ ifResultMapping[yieldOperandIdx] = ifResultIdx;
1749+ // If this `ifOp` result is vector type and it is yielded by the
1750+ // `WarpOp`, we keep track the distributed type for this result.
1751+ if (!isa<VectorType>(ifResult.getType ()))
1752+ continue ;
1753+ VectorType distType =
1754+ cast<VectorType>(warpOp.getResult (yieldOperandIdx).getType ());
1755+ ifResultDistTypes[ifResultIdx] = distType;
1756+ }
1757+
1758+ // Collect `WarpOp`-defined values used in `ifOp`, the new warp op returns
1759+ // them
1760+ auto getEscapingValues = [&](Region &branch,
1761+ llvm::SmallSetVector<Value, 32 > &values,
1762+ SmallVector<Type> &inputTypes,
1763+ SmallVector<Type> &distTypes) {
1764+ if (branch.empty ())
1765+ return ;
1766+ mlir::visitUsedValuesDefinedAbove (branch, [&](OpOperand *operand) {
1767+ Operation *parent = operand->get ().getParentRegion ()->getParentOp ();
1768+ if (warpOp->isAncestor (parent)) {
1769+ if (!values.insert (operand->get ()))
1770+ return ;
1771+ Type distType = operand->get ().getType ();
1772+ if (auto vecType = dyn_cast<VectorType>(distType)) {
1773+ AffineMap map = distributionMapFn (operand->get ());
1774+ distType = getDistributedType (vecType, map, warpOp.getWarpSize ());
1775+ }
1776+ inputTypes.push_back (operand->get ().getType ());
1777+ distTypes.push_back (distType);
1778+ }
1779+ });
1780+ };
1781+ llvm::SmallSetVector<Value, 32 > escapingValuesThen;
1782+ SmallVector<Type> escapingValueInputTypesThen; // inner warp op block args
1783+ SmallVector<Type> escapingValueDistTypesThen; // new warp returns
1784+ getEscapingValues (ifOp.getThenRegion (), escapingValuesThen,
1785+ escapingValueInputTypesThen, escapingValueDistTypesThen);
1786+ llvm::SmallSetVector<Value, 32 > escapingValuesElse;
1787+ SmallVector<Type> escapingValueInputTypesElse; // inner warp op block args
1788+ SmallVector<Type> escapingValueDistTypesElse; // new warp returns
1789+ getEscapingValues (ifOp.getElseRegion (), escapingValuesElse,
1790+ escapingValueInputTypesElse, escapingValueDistTypesElse);
1791+
1792+ if (llvm::is_contained (escapingValueDistTypesThen, Type{}) ||
1793+ llvm::is_contained (escapingValueDistTypesElse, Type{}))
1794+ return failure ();
1795+
1796+ // The new `WarpOp` groups yields values in following order:
1797+ // 1. Escaping values then branch
1798+ // 2. Escaping values else branch
1799+ // 3. All non-`ifOp` yielded values.
1800+ SmallVector<Value> newWarpOpYieldValues{escapingValuesThen.begin (),
1801+ escapingValuesThen.end ()};
1802+ newWarpOpYieldValues.append (escapingValuesElse.begin (),
1803+ escapingValuesElse.end ());
1804+ SmallVector<Type> newWarpOpDistTypes = escapingValueDistTypesThen;
1805+ newWarpOpDistTypes.append (escapingValueDistTypesElse.begin (),
1806+ escapingValueDistTypesElse.end ());
1807+
1808+ llvm::SmallDenseMap<unsigned , unsigned > origToNewYieldIdx;
1809+ for (auto [idx, val] :
1810+ llvm::zip_equal (nonIfYieldIndices, nonIfYieldValues)) {
1811+ origToNewYieldIdx[idx] = newWarpOpYieldValues.size ();
1812+ newWarpOpYieldValues.push_back (val);
1813+ newWarpOpDistTypes.push_back (warpOp.getResult (idx).getType ());
1814+ }
1815+ // Create the new `WarpOp` with the updated yield values and types.
1816+ WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns (
1817+ rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
1818+
1819+ // `ifOp` returns the result of the inner warp op.
1820+ SmallVector<Type> newIfOpDistResTypes;
1821+ for (auto [i, res] : llvm::enumerate (ifOp.getResults ())) {
1822+ Type distType = cast<Value>(res).getType ();
1823+ if (auto vecType = dyn_cast<VectorType>(distType)) {
1824+ AffineMap map = distributionMapFn (cast<Value>(res));
1825+ distType = ifResultDistTypes.count (i)
1826+ ? ifResultDistTypes[i]
1827+ : getDistributedType (vecType, map, warpOp.getWarpSize ());
1828+ }
1829+ newIfOpDistResTypes.push_back (distType);
1830+ }
1831+ // Create a new `IfOp` outside the new `WarpOp` region.
1832+ OpBuilder::InsertionGuard g (rewriter);
1833+ rewriter.setInsertionPointAfter (newWarpOp);
1834+ auto newIfOp = scf::IfOp::create (rewriter, ifOp.getLoc (),
1835+ newIfOpDistResTypes, ifOp.getCondition (),
1836+ static_cast <bool >(ifOp.thenBlock ()),
1837+ static_cast <bool >(ifOp.elseBlock ()));
1838+
1839+ auto processBranch = [&](Block *oldIfBranch, Block *newIfBranch,
1840+ llvm::SmallSetVector<Value, 32 > &escapingValues,
1841+ SmallVector<Type> &escapingValueInputTypes) {
1842+ OpBuilder::InsertionGuard g (rewriter);
1843+ if (!newIfBranch)
1844+ return ;
1845+ rewriter.setInsertionPointToStart (newIfBranch);
1846+ llvm::SmallDenseMap<Value, int64_t > escapeValToBlockArgIndex;
1847+ SmallVector<Value> innerWarpInputVals;
1848+ SmallVector<Type> innerWarpInputTypes;
1849+ for (size_t i = 0 ; i < escapingValues.size (); ++i) {
1850+ innerWarpInputVals.push_back (newWarpOp.getResult (i));
1851+ escapeValToBlockArgIndex[escapingValues[i]] =
1852+ innerWarpInputTypes.size ();
1853+ innerWarpInputTypes.push_back (escapingValueInputTypes[i]);
1854+ }
1855+ auto innerWarp = WarpExecuteOnLane0Op::create (
1856+ rewriter, newWarpOp.getLoc (), newIfOp.getResultTypes (),
1857+ newWarpOp.getLaneid (), newWarpOp.getWarpSize (), innerWarpInputVals,
1858+ innerWarpInputTypes);
1859+
1860+ innerWarp.getWarpRegion ().takeBody (*oldIfBranch->getParent ());
1861+ innerWarp.getWarpRegion ().addArguments (
1862+ innerWarpInputTypes,
1863+ SmallVector<Location>(innerWarpInputTypes.size (), ifOp.getLoc ()));
1864+
1865+ SmallVector<Value> yieldOperands;
1866+ for (Value operand : oldIfBranch->getTerminator ()->getOperands ())
1867+ yieldOperands.push_back (operand);
1868+ rewriter.eraseOp (oldIfBranch->getTerminator ());
1869+
1870+ rewriter.setInsertionPointToEnd (innerWarp.getBody ());
1871+ gpu::YieldOp::create (rewriter, innerWarp.getLoc (), yieldOperands);
1872+ rewriter.setInsertionPointAfter (innerWarp);
1873+ scf::YieldOp::create (rewriter, ifOp.getLoc (), innerWarp.getResults ());
1874+
1875+ // Update any users of escaping values that were forwarded to the
1876+ // inner `WarpOp`. These values are now arguments of the inner `WarpOp`.
1877+ innerWarp.walk ([&](Operation *op) {
1878+ for (OpOperand &operand : op->getOpOperands ()) {
1879+ auto it = escapeValToBlockArgIndex.find (operand.get ());
1880+ if (it == escapeValToBlockArgIndex.end ())
1881+ continue ;
1882+ operand.set (innerWarp.getBodyRegion ().getArgument (it->second ));
1883+ }
1884+ });
1885+ mlir::vector::moveScalarUniformCode (innerWarp);
1886+ };
1887+ processBranch (&ifOp.getThenRegion ().front (),
1888+ &newIfOp.getThenRegion ().front (), escapingValuesThen,
1889+ escapingValueInputTypesThen);
1890+ if (!ifOp.getElseRegion ().empty ())
1891+ processBranch (&ifOp.getElseRegion ().front (),
1892+ &newIfOp.getElseRegion ().front (), escapingValuesElse,
1893+ escapingValueInputTypesElse);
1894+ // Update the users of `<- WarpOp.yield <- IfOp.yield` to use the new `IfOp`
1895+ // result.
1896+ for (auto [origIdx, newIdx] : ifResultMapping)
1897+ rewriter.replaceAllUsesExcept (warpOp.getResult (origIdx),
1898+ newIfOp.getResult (newIdx), newIfOp);
1899+ // Similarly, update any users of the `WarpOp` results that were not
1900+ // results of the `IfOp`.
1901+ for (auto [origIdx, newIdx] : origToNewYieldIdx)
1902+ rewriter.replaceAllUsesWith (warpOp.getResult (origIdx),
1903+ newWarpOp.getResult (newIdx));
1904+ // Remove the original `WarpOp` and `IfOp`, they should not have any uses
1905+ // at this point.
1906+ rewriter.eraseOp (ifOp);
1907+ rewriter.eraseOp (warpOp);
1908+ return success ();
1909+ }
1910+
1911+ private:
1912+ DistributionMapFn distributionMapFn;
1913+ };
1914+
17161915// / Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
17171916// / the scf.ForOp is the last operation in the region so that it doesn't
17181917// / change the order of execution. This creates a new scf.for region after the
@@ -2068,6 +2267,8 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
20682267 benefit);
20692268 patterns.add <WarpOpScfForOp>(patterns.getContext (), distributionMapFn,
20702269 benefit);
2270+ patterns.add <WarpOpScfIfOp>(patterns.getContext (), distributionMapFn,
2271+ benefit);
20712272}
20722273
20732274void mlir::vector::populateDistributeReduction (
0 commit comments