@@ -371,6 +371,36 @@ static VectorType getDistributedType(VectorType originalType, AffineMap map,
371371 return targetType;
372372}
373373
374+ /// Collect the values used inside `innerRegion` that are defined above it and
375+ /// whose defining region belongs to `warpOp` or to an op nested in it. Returns
376+ /// the deduplicated values, their types, and their distributed types (vector
377+ /// types are mapped via `distributionMapFn`); all three are index-aligned.
378+ std::tuple<llvm::SmallSetVector<Value, 32 >, SmallVector<Type>,
379+ SmallVector<Type>>
380+ getInnerRegionEscapingValues (WarpExecuteOnLane0Op warpOp, Region &innerRegion,
381+ DistributionMapFn distributionMapFn) {
382+ llvm::SmallSetVector<Value, 32 > escapingValues;
383+ SmallVector<Type> escapingValueTypes;
384+ SmallVector<Type> escapingValueDistTypes; // types to yield from the new warpOp
385+ if (innerRegion.empty ())
386+ return {escapingValues, escapingValueTypes, escapingValueDistTypes}; // no blocks to visit: all three results are empty
387+ mlir::visitUsedValuesDefinedAbove (innerRegion, [&](OpOperand *operand) {
388+ Operation *parent = operand->get ().getParentRegion ()->getParentOp ();
389+ if (warpOp->isAncestor (parent)) { // only keep values defined inside warpOp
390+ if (!escapingValues.insert (operand->get ()))
391+ return ; // value already recorded by an earlier use
392+ Type distType = operand->get ().getType ();
393+ if (auto vecType = dyn_cast<VectorType>(distType)) {
394+ AffineMap map = distributionMapFn (operand->get ());
395+ distType = getDistributedType (vecType, map, warpOp.getWarpSize ()); // WarpOpScfIfOp bails out if this is a null Type
396+ }
397+ escapingValueTypes.push_back (operand->get ().getType ()); // non-vector values keep their original type in both lists
398+ escapingValueDistTypes.push_back (distType);
399+ }
400+ });
401+ return {escapingValues, escapingValueTypes, escapingValueDistTypes};
402+ }
403+
374404// / Distribute transfer_write ops based on the affine map returned by
375405// / `distributionMapFn`. Writes of size more than `maxNumElementToExtract`
376406// / will not be distributed (it should be less than the warp size).
@@ -1713,6 +1743,32 @@ struct WarpOpInsert : public WarpDistributionPattern {
17131743 }
17141744};
17151745
1746+ /// Sink scf.if out of WarpExecuteOnLane0Op. This can be done only if
1747+ /// the scf.if is the last operation in the region so that it doesn't
1748+ /// change the order of execution. This creates a new scf.if after the
1749+ /// WarpExecuteOnLane0Op. Each branch of the new scf.if is enclosed in
1750+ /// the "inner" WarpExecuteOnLane0Op. Example:
1751+ /// ```
1752+ /// gpu.warp_execute_on_lane_0(%laneid)[32] {
1753+ ///   %payload = ... : vector<32xindex>
1754+ ///   scf.if %pred {
1755+ ///     vector.store %payload, %buffer[%idx] : memref<128xindex>,
1756+ ///       vector<32xindex>
1757+ ///   }
1758+ ///   gpu.yield
1759+ /// }
1760+ /// // The IR above is rewritten to:
1761+ /// %r = gpu.warp_execute_on_lane_0(%laneid)[32] {
1762+ ///   %payload = ... : vector<32xindex>
1763+ ///   gpu.yield %payload : vector<32xindex>
1764+ /// }
1765+ /// scf.if %pred {
1766+ ///   gpu.warp_execute_on_lane_0(%laneid)[32] args(%r : vector<1xindex>) {
1767+ ///   ^bb0(%arg1: vector<32xindex>):
1768+ ///     vector.store %arg1, %buffer[%idx] : memref<128xindex>, vector<32xindex>
1769+ ///   }
1770+ /// }
1771+ /// ```
17161772struct WarpOpScfIfOp : public WarpDistributionPattern {
17171773 WarpOpScfIfOp (MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1 )
17181774 : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
@@ -1728,7 +1784,7 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
17281784 // The current `WarpOp` can yield two types of values:
17291785 // 1. Not results of `IfOp`:
17301786 // Preserve them in the new `WarpOp`.
1731- // Collect their yield index.
1787+ // Collect their yield index to remap their uses.
17321788 // 2. Results of `IfOp`:
17331789 // They are not part of the new `WarpOp` results.
17341790 // Map current warp's yield operand index to `IfOp` result idx.
@@ -1757,38 +1813,14 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
17571813
17581814 // Collect `WarpOp`-defined values used in `ifOp`, the new warp op returns
17591815 // them
1760- auto getEscapingValues = [&](Region &branch,
1761- llvm::SmallSetVector<Value, 32 > &values,
1762- SmallVector<Type> &inputTypes,
1763- SmallVector<Type> &distTypes) {
1764- if (branch.empty ())
1765- return ;
1766- mlir::visitUsedValuesDefinedAbove (branch, [&](OpOperand *operand) {
1767- Operation *parent = operand->get ().getParentRegion ()->getParentOp ();
1768- if (warpOp->isAncestor (parent)) {
1769- if (!values.insert (operand->get ()))
1770- return ;
1771- Type distType = operand->get ().getType ();
1772- if (auto vecType = dyn_cast<VectorType>(distType)) {
1773- AffineMap map = distributionMapFn (operand->get ());
1774- distType = getDistributedType (vecType, map, warpOp.getWarpSize ());
1775- }
1776- inputTypes.push_back (operand->get ().getType ());
1777- distTypes.push_back (distType);
1778- }
1779- });
1780- };
1781- llvm::SmallSetVector<Value, 32 > escapingValuesThen;
1782- SmallVector<Type> escapingValueInputTypesThen; // inner warp op block args
1783- SmallVector<Type> escapingValueDistTypesThen; // new warp returns
1784- getEscapingValues (ifOp.getThenRegion (), escapingValuesThen,
1785- escapingValueInputTypesThen, escapingValueDistTypesThen);
1786- llvm::SmallSetVector<Value, 32 > escapingValuesElse;
1787- SmallVector<Type> escapingValueInputTypesElse; // inner warp op block args
1788- SmallVector<Type> escapingValueDistTypesElse; // new warp returns
1789- getEscapingValues (ifOp.getElseRegion (), escapingValuesElse,
1790- escapingValueInputTypesElse, escapingValueDistTypesElse);
1791-
1816+ auto [escapingValuesThen, escapingValueInputTypesThen,
1817+ escapingValueDistTypesThen] =
1818+ getInnerRegionEscapingValues (warpOp, ifOp.getThenRegion (),
1819+ distributionMapFn);
1820+ auto [escapingValuesElse, escapingValueInputTypesElse,
1821+ escapingValueDistTypesElse] =
1822+ getInnerRegionEscapingValues (warpOp, ifOp.getElseRegion (),
1823+ distributionMapFn);
17921824 if (llvm::is_contained (escapingValueDistTypesThen, Type{}) ||
17931825 llvm::is_contained (escapingValueDistTypesElse, Type{}))
17941826 return failure ();
@@ -1825,6 +1857,7 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
18251857 Type distType = cast<Value>(res).getType ();
18261858 if (auto vecType = dyn_cast<VectorType>(distType)) {
18271859 AffineMap map = distributionMapFn (cast<Value>(res));
1860+ // Fall back to the affine map if the distributed result type was not previously recorded.
18281861 distType = ifResultDistTypes.count (i)
18291862 ? ifResultDistTypes[i]
18301863 : getDistributedType (vecType, map, warpOp.getWarpSize ());
@@ -1838,63 +1871,66 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
18381871 rewriter, ifOp.getLoc (), newIfOpDistResTypes, newWarpOp.getResult (0 ),
18391872 static_cast <bool >(ifOp.thenBlock ()),
18401873 static_cast <bool >(ifOp.elseBlock ()));
1841-
1842- auto processBranch = [&](Block *oldIfBranch, Block *newIfBranch,
1843- llvm::SmallSetVector<Value, 32 > &escapingValues,
1844- SmallVector<Type> &escapingValueInputTypes,
1845- size_t warpResRangeStart) {
1846- OpBuilder::InsertionGuard g (rewriter);
1847- if (!newIfBranch)
1848- return ;
1849- rewriter.setInsertionPointToStart (newIfBranch);
1850- llvm::SmallDenseMap<Value, int64_t > escapeValToBlockArgIndex;
1851- SmallVector<Value> innerWarpInputVals;
1852- SmallVector<Type> innerWarpInputTypes;
1853- for (size_t i = 0 ; i < escapingValues.size (); ++i, ++warpResRangeStart) {
1854- innerWarpInputVals.push_back (newWarpOp.getResult (warpResRangeStart));
1855- escapeValToBlockArgIndex[escapingValues[i]] =
1856- innerWarpInputTypes.size ();
1857- innerWarpInputTypes.push_back (escapingValueInputTypes[i]);
1858- }
1859- auto innerWarp = WarpExecuteOnLane0Op::create (
1860- rewriter, newWarpOp.getLoc (), newIfOp.getResultTypes (),
1861- newWarpOp.getLaneid (), newWarpOp.getWarpSize (), innerWarpInputVals,
1862- innerWarpInputTypes);
1863-
1864- innerWarp.getWarpRegion ().takeBody (*oldIfBranch->getParent ());
1865- innerWarp.getWarpRegion ().addArguments (
1866- innerWarpInputTypes,
1867- SmallVector<Location>(innerWarpInputTypes.size (), ifOp.getLoc ()));
1868-
1869- SmallVector<Value> yieldOperands;
1870- for (Value operand : oldIfBranch->getTerminator ()->getOperands ())
1871- yieldOperands.push_back (operand);
1872- rewriter.eraseOp (oldIfBranch->getTerminator ());
1873-
1874- rewriter.setInsertionPointToEnd (innerWarp.getBody ());
1875- gpu::YieldOp::create (rewriter, innerWarp.getLoc (), yieldOperands);
1876- rewriter.setInsertionPointAfter (innerWarp);
1877- scf::YieldOp::create (rewriter, ifOp.getLoc (), innerWarp.getResults ());
1878-
1879- // Update any users of escaping values that were forwarded to the
1880- // inner `WarpOp`. These values are now arguments of the inner `WarpOp`.
1881- innerWarp.walk ([&](Operation *op) {
1882- for (OpOperand &operand : op->getOpOperands ()) {
1883- auto it = escapeValToBlockArgIndex.find (operand.get ());
1884- if (it == escapeValToBlockArgIndex.end ())
1885- continue ;
1886- operand.set (innerWarp.getBodyRegion ().getArgument (it->second ));
1887- }
1888- });
1889- mlir::vector::moveScalarUniformCode (innerWarp);
1890- };
1891- processBranch (&ifOp.getThenRegion ().front (),
1892- &newIfOp.getThenRegion ().front (), escapingValuesThen,
1893- escapingValueInputTypesThen, 1 );
1874+ auto encloseRegionInWarpOp =
1875+ [&](Block *oldIfBranch, Block *newIfBranch,
1876+ llvm::SmallSetVector<Value, 32 > &escapingValues,
1877+ SmallVector<Type> &escapingValueInputTypes,
1878+ size_t warpResRangeStart) {
1879+ OpBuilder::InsertionGuard g (rewriter);
1880+ if (!newIfBranch)
1881+ return ;
1882+ rewriter.setInsertionPointToStart (newIfBranch);
1883+ llvm::SmallDenseMap<Value, int64_t > escapeValToBlockArgIndex;
1884+ SmallVector<Value> innerWarpInputVals;
1885+ SmallVector<Type> innerWarpInputTypes;
1886+ for (size_t i = 0 ; i < escapingValues.size ();
1887+ ++i, ++warpResRangeStart) {
1888+ innerWarpInputVals.push_back (
1889+ newWarpOp.getResult (warpResRangeStart));
1890+ escapeValToBlockArgIndex[escapingValues[i]] =
1891+ innerWarpInputTypes.size ();
1892+ innerWarpInputTypes.push_back (escapingValueInputTypes[i]);
1893+ }
1894+ auto innerWarp = WarpExecuteOnLane0Op::create (
1895+ rewriter, newWarpOp.getLoc (), newIfOp.getResultTypes (),
1896+ newWarpOp.getLaneid (), newWarpOp.getWarpSize (),
1897+ innerWarpInputVals, innerWarpInputTypes);
1898+
1899+ innerWarp.getWarpRegion ().takeBody (*oldIfBranch->getParent ());
1900+ innerWarp.getWarpRegion ().addArguments (
1901+ innerWarpInputTypes,
1902+ SmallVector<Location>(innerWarpInputTypes.size (), ifOp.getLoc ()));
1903+
1904+ SmallVector<Value> yieldOperands;
1905+ for (Value operand : oldIfBranch->getTerminator ()->getOperands ())
1906+ yieldOperands.push_back (operand);
1907+ rewriter.eraseOp (oldIfBranch->getTerminator ());
1908+
1909+ rewriter.setInsertionPointToEnd (innerWarp.getBody ());
1910+ gpu::YieldOp::create (rewriter, innerWarp.getLoc (), yieldOperands);
1911+ rewriter.setInsertionPointAfter (innerWarp);
1912+ scf::YieldOp::create (rewriter, ifOp.getLoc (), innerWarp.getResults ());
1913+
1914+ // Update any users of escaping values that were forwarded to the
1915+ // inner `WarpOp`. These values are arguments of the inner `WarpOp`.
1916+ innerWarp.walk ([&](Operation *op) {
1917+ for (OpOperand &operand : op->getOpOperands ()) {
1918+ auto it = escapeValToBlockArgIndex.find (operand.get ());
1919+ if (it == escapeValToBlockArgIndex.end ())
1920+ continue ;
1921+ operand.set (innerWarp.getBodyRegion ().getArgument (it->second ));
1922+ }
1923+ });
1924+ mlir::vector::moveScalarUniformCode (innerWarp);
1925+ };
1926+ encloseRegionInWarpOp (&ifOp.getThenRegion ().front (),
1927+ &newIfOp.getThenRegion ().front (), escapingValuesThen,
1928+ escapingValueInputTypesThen, 1 );
18941929 if (!ifOp.getElseRegion ().empty ())
1895- processBranch (&ifOp.getElseRegion ().front (),
1896- &newIfOp.getElseRegion ().front (), escapingValuesElse,
1897- escapingValueInputTypesElse, 1 + escapingValuesThen.size ());
1930+ encloseRegionInWarpOp (&ifOp.getElseRegion ().front (),
1931+ &newIfOp.getElseRegion ().front (),
1932+ escapingValuesElse, escapingValueInputTypesElse,
1933+ 1 + escapingValuesThen.size ());
18981934 // Update the users of `<- WarpOp.yield <- IfOp.yield` to use the new `IfOp`
18991935 // result.
19001936 for (auto [origIdx, newIdx] : ifResultMapping)
0 commit comments