@@ -1749,19 +1749,18 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
17491749 : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
17501750 LogicalResult matchAndRewrite (WarpExecuteOnLane0Op warpOp,
17511751 PatternRewriter &rewriter) const override {
1752- auto yield = cast<gpu::YieldOp>(
1752+ auto newWarpOpYield = cast<gpu::YieldOp>(
17531753 warpOp.getBodyRegion ().getBlocks ().begin ()->getTerminator ());
17541754 // Only pick up forOp if it is the last op in the region.
1755- Operation *lastNode = yield ->getPrevNode ();
1755+ Operation *lastNode = newWarpOpYield ->getPrevNode ();
17561756 auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
17571757 if (!forOp)
17581758 return failure ();
17591759 // Collect Values that come from the warp op but are outside the forOp.
1760- // Those Value needs to be returned by the original warpOp and passed to
1761- // the new op.
1760+ // Those Values need to be returned by the new warp op.
17621761 llvm::SmallSetVector<Value, 32 > escapingValues;
1763- SmallVector<Type> inputTypes ;
1764- SmallVector<Type> distTypes ;
1762+ SmallVector<Type> escapingValueInputTypes ;
1763+ SmallVector<Type> escapingValuedistTypes ;
17651764 mlir::visitUsedValuesDefinedAbove (
17661765 forOp.getBodyRegion (), [&](OpOperand *operand) {
17671766 Operation *parent = operand->get ().getParentRegion ()->getParentOp ();
@@ -1773,183 +1772,155 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
17731772 AffineMap map = distributionMapFn (operand->get ());
17741773 distType = getDistributedType (vecType, map, warpOp.getWarpSize ());
17751774 }
1776- inputTypes .push_back (operand->get ().getType ());
1777- distTypes .push_back (distType);
1775+ escapingValueInputTypes .push_back (operand->get ().getType ());
1776+ escapingValuedistTypes .push_back (distType);
17781777 }
17791778 });
17801779
1781- if (llvm::is_contained (distTypes , Type{}))
1780+ if (llvm::is_contained (escapingValuedistTypes , Type{}))
17821781 return failure ();
1783-
1782+ // Warp op can yield two types of values:
1783+ // 1. Values that are not results of the forOp:
1784+ // These values must also be yielded by the new warp op. Also, we need to
1785+ // record the index mapping for these values to replace them later.
1786+ // 2. Values that are results of the forOp:
1787+ // In this case, we record the index mapping between the warp op result
1788+ // index and matching forOp result index.
17841789 SmallVector<Value> nonForYieldedValues;
1785- // SmallVector<Type> nonForYieldedTypes;
17861790 SmallVector<unsigned > nonForResultIndices;
1787-
1788- // record result mapping.
17891791 DenseMap<unsigned , unsigned > forResultMapping;
1790- DenseMap<unsigned , unsigned > warpResultMapping;
1791- // llvm::SmallDenseMap<unsigned, unsigned> forResultToWarpResultMapping;
1792- for (OpOperand &yieldOperand : yield->getOpOperands ()) {
1792+ for (OpOperand &yieldOperand : newWarpOpYield->getOpOperands ()) {
1793+ // Yielded value is not a result of the forOp.
17931794 if (yieldOperand.get ().getDefiningOp () != forOp.getOperation ()) {
17941795 nonForYieldedValues.push_back (yieldOperand.get ());
1795- // nonForYieldedTypes.push_back(
1796- // warpOp.getResult(yieldOperand.getOperandNumber()).getType());
17971796 nonForResultIndices.push_back (yieldOperand.getOperandNumber ());
17981797 continue ;
17991798 }
18001799 OpResult forResult = cast<OpResult>(yieldOperand.get ());
18011800 forResultMapping[yieldOperand.getOperandNumber ()] =
18021801 forResult.getResultNumber ();
1803- // forResultToWarpResultMapping[forResult.getResultNumber()] =
1804- // yieldOperand.getOperandNumber();
1805- // yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
18061802 }
18071803
1808- // llvm::errs() << "non for yielded values size: "
1809- // << nonForYieldedValues.size() << "\n";
1810-
1811- // llvm::errs() << "escpaing values size: " << escapingValues.size() <<
1812- // "\n";
1813- SmallVector<Value> yieldedValuesFromWarpOp;
1814- SmallVector<Type> yieldedTypesFromWarpOp;
1815- // All init args of the forOp are yielded from the original warp op.
1804+ // Newly created warp op will yield values in the following order:
1805+ // 1. All init args of the forOp.
1806+ // 2. All escaping values.
1807+ // 3. All non-for yielded values.
1808+ SmallVector<Value> newWarpOpYieldValues;
1809+ SmallVector<Type> newWarpOpDistTypes;
18161810 for (auto [i, initArg] : llvm::enumerate (forOp.getInitArgs ())) {
1817- yieldedValuesFromWarpOp .push_back (initArg);
1818- // find distributed type for the init arg.
1811+ newWarpOpYieldValues .push_back (initArg);
1812+ // Compute the distributed type for this init arg.
18191813 Type distType = initArg.getType ();
18201814 if (auto vecType = dyn_cast<VectorType>(distType)) {
1821- // if (forResultToWarpResultMapping.contains(i)) {
1822- // // If the init arg is yielded from the warp op, we need to compute
1823- // the
1824- // // distributed type.
1825- // distType =
1826- // warpOp.getResult(forResultToWarpResultMapping[i]).getType();
1827- // } else {
18281815 AffineMap map = distributionMapFn (initArg);
18291816 distType = getDistributedType (vecType, map, warpOp.getWarpSize ());
1830- // }
18311817 }
1832- // llvm::errs() << "distributed type: " << distType << "\n";
1833- yieldedTypesFromWarpOp.push_back (distType);
1818+ newWarpOpDistTypes.push_back (distType);
18341819 }
1835- // All escaping values are yielded from the original warp op.
1836- yieldedValuesFromWarpOp.insert (yieldedValuesFromWarpOp.end (),
1837- escapingValues.begin (),
1838- escapingValues.end ());
1839- yieldedTypesFromWarpOp.insert (yieldedTypesFromWarpOp.end (),
1840- distTypes.begin (), distTypes.end ());
1841-
1820+ // Insert escaping values and their distributed types.
1821+ newWarpOpYieldValues.insert (newWarpOpYieldValues.end (),
1822+ escapingValues.begin (), escapingValues.end ());
1823+ newWarpOpDistTypes.insert (newWarpOpDistTypes.end (),
1824+ escapingValuedistTypes.begin (),
1825+ escapingValuedistTypes.end ());
1826+ // Next, we insert all non-for yielded values and their distributed types.
1827+ // We also create a mapping between the non-for yielded value index and the
1828+ // corresponding new warp op yield value index (needed to update users
1829+ // later).
1830+ DenseMap<unsigned , unsigned > warpResultMapping;
18421831 for (auto [i, v] : llvm::enumerate (nonForYieldedValues)) {
1843- warpResultMapping[nonForResultIndices[i]] =
1844- yieldedValuesFromWarpOp.size ();
1845- yieldedValuesFromWarpOp.push_back (v);
1846- yieldedTypesFromWarpOp.push_back (
1832+ warpResultMapping[nonForResultIndices[i]] = newWarpOpYieldValues.size ();
1833+ newWarpOpYieldValues.push_back (v);
1834+ newWarpOpDistTypes.push_back (
18471835 warpOp.getResult (nonForResultIndices[i]).getType ());
18481836 }
1849-
1850- // SmallVector<size_t> newRetIndices;
1837+ // Create the new warp op with the updated yield values and types.
18511838 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns (
1852- rewriter, warpOp, yieldedValuesFromWarpOp, yieldedTypesFromWarpOp );
1853- yield = cast<gpu::YieldOp>(
1839+ rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes );
1840+ newWarpOpYield = cast<gpu::YieldOp>(
18541841 newWarpOp.getBodyRegion ().getBlocks ().begin ()->getTerminator ());
18551842
1856- // newWarpOp->print(llvm::outs());
1857- // llvm::outs() << "\n";
1858-
1859- SmallVector<Value> newOperands;
1860- // Collect the new init args coming from the new warp op.
1861- for (size_t i = 0 ; i < forOp.getInitArgs ().size (); ++i)
1862- newOperands.push_back (newWarpOp.getResult (i));
1863- // for (OpOperand &yieldOperand : yield->getOpOperands()) {
1864- // if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
1865- // continue;
1866- // OpResult forResult = cast<OpResult>(yieldOperand.get());
1867- // resultIdx.push_back(forResult.getResultNumber());
1868- // yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
1869- // }
1843+ // Next, we create a new for op with the init args yielded by the new
1844+ // warp op.
1845+ unsigned escapingValuesStartIdx =
1846+ forOp.getInitArgs ().size (); // ForOp init args are positioned before
1847+ // escaping values in the new warp op.
1848+ SmallVector<Value> newForOpOperands;
1849+ for (size_t i = 0 ; i < escapingValuesStartIdx; ++i)
1850+ newForOpOperands.push_back (newWarpOp.getResult (i));
18701851
1852+ // Create a new for op outside the new warp op region.
18711853 OpBuilder::InsertionGuard g (rewriter);
18721854 rewriter.setInsertionPointAfter (newWarpOp);
1873-
1874- // Create a new for op outside the region with a WarpExecuteOnLane0Op
1875- // region inside.
18761855 auto newForOp = rewriter.create <scf::ForOp>(
18771856 forOp.getLoc (), forOp.getLowerBound (), forOp.getUpperBound (),
1878- forOp.getStep (), newOperands);
1857+ forOp.getStep (), newForOpOperands);
1858+ // Next, we insert a new warp op (called inner warp op) inside the
1859+ // newly created for op. This warp op will contain all ops that were
1860+ // contained within the original for op body.
18791861 rewriter.setInsertionPointToStart (newForOp.getBody ());
18801862
1881- SmallVector<Value> warpInput (newForOp.getRegionIterArgs ().begin (),
1882- newForOp.getRegionIterArgs ().end ());
1883- SmallVector<Type> warpInputType (forOp.getResultTypes ().begin (),
1884- forOp.getResultTypes ().end ());
1863+ SmallVector<Value> innerWarpInput (newForOp.getRegionIterArgs ().begin (),
1864+ newForOp.getRegionIterArgs ().end ());
1865+ SmallVector<Type> innerWarpInputType (forOp.getResultTypes ().begin (),
1866+ forOp.getResultTypes ().end ());
1867+ // Escaping values are forwarded to the inner warp op as its (additional)
1868+ // arguments. We keep track of the mapping between these values and their
1869+ // argument index in the inner warp op (to replace uses later).
18851870 llvm::SmallDenseMap<Value, int64_t > argIndexMapping;
1886- // llvm::errs() << "setting arg index mapping\n";
1887- unsigned escapingValuesStartIdx = forOp.getInitArgs ().size ();
18881871 for (size_t i = escapingValuesStartIdx;
18891872 i < escapingValuesStartIdx + escapingValues.size (); ++i) {
1890- warpInput .push_back (newWarpOp.getResult (i));
1873+ innerWarpInput .push_back (newWarpOp.getResult (i));
18911874 argIndexMapping[escapingValues[i - escapingValuesStartIdx]] =
1892- warpInputType.size ();
1893- warpInputType.push_back (inputTypes[i - escapingValuesStartIdx]);
1875+ innerWarpInputType.size ();
1876+ innerWarpInputType.push_back (
1877+ escapingValueInputTypes[i - escapingValuesStartIdx]);
18941878 }
1895- // for (auto [i, r] : llvm::enumerate(
1896- // newWarpOp.getResults().drop_front(forOp.getInitArgs().size())))
1897- // {
1898- // warpInput.push_back(r);
1899- // argIndexMapping[escapingValues[i]] = warpInputType.size();
1900- // warpInputType.push_back(inputTypes[i]);
1901- // }
1902- // llvm::errs() << "go here\n";
1879+ // Create the inner warp op with the new input values and types.
19031880 auto innerWarp = rewriter.create <WarpExecuteOnLane0Op>(
19041881 newWarpOp.getLoc (), newForOp.getResultTypes (), newWarpOp.getLaneid (),
1905- newWarpOp.getWarpSize (), warpInput, warpInputType);
1906- // newForOp->getParentOp()->print(llvm::outs());
1907- // llvm::outs() << "\n";
1882+ newWarpOp.getWarpSize (), innerWarpInput, innerWarpInputType);
19081883
1884+ // Inline the for op body into the inner warp op body.
19091885 SmallVector<Value> argMapping;
19101886 argMapping.push_back (newForOp.getInductionVar ());
1911- for (Value args : innerWarp.getBody ()->getArguments ()) {
1887+ for (Value args : innerWarp.getBody ()->getArguments ())
19121888 argMapping.push_back (args);
1913- }
1914- auto forOpCopy = cast<scf::ForOp>(rewriter.clone (*forOp.getOperation ()));
1915- argMapping.resize (forOpCopy.getBody ()->getNumArguments ());
1889+
1890+ argMapping.resize (forOp.getBody ()->getNumArguments ());
19161891 SmallVector<Value> yieldOperands;
1917- for (Value operand : forOpCopy .getBody ()->getTerminator ()->getOperands ())
1892+ for (Value operand : forOp .getBody ()->getTerminator ()->getOperands ())
19181893 yieldOperands.push_back (operand);
19191894
1920- rewriter.eraseOp (forOpCopy.getBody ()->getTerminator ());
1921- rewriter.mergeBlocks (forOpCopy.getBody (), innerWarp.getBody (), argMapping);
1895+ rewriter.eraseOp (forOp.getBody ()->getTerminator ());
1896+ rewriter.mergeBlocks (forOp.getBody (), innerWarp.getBody (), argMapping);
1897+
1898+ // Insert a gpu yieldOp at the end of the inner warp op body that yields
1899+ // original forOp results.
19221900 rewriter.setInsertionPointToEnd (innerWarp.getBody ());
19231901 rewriter.create <gpu::YieldOp>(innerWarp.getLoc (), yieldOperands);
19241902 rewriter.setInsertionPointAfter (innerWarp);
1903+ // Insert a scf.yield op at the end of the new for op body that yields
1904+ // the inner warp op results.
19251905 if (!innerWarp.getResults ().empty ())
1926- rewriter.create <scf::YieldOp>(forOpCopy.getLoc (), innerWarp.getResults ());
1927- // forOpCopy->getParentOp()->getParentOp()->print(llvm::outs());
1928- // llvm::outs() << "\n";
1929- // llvm::errs() << "erasing for op\n";
1930-
1931- rewriter.eraseOp (forOpCopy);
1932- // Replace the warpOp result coming from the original ForOp.
1933- // print resultIdx for debugging.
1934- // llvm::errs() << "resultIdx: ";
1935- // for (auto idx : resultIdx)
1936- // llvm::errs() << idx << " ";
1937- // llvm::errs() << "\n";
1938- for (auto [origIdx, newIdx] : forResultMapping) {
1906+ rewriter.create <scf::YieldOp>(forOp.getLoc (), innerWarp.getResults ());
1907+
1908+ // Update the users of original warp op results that were coming from the
1909+ // original forOp to the corresponding new forOp result.
1910+ for (auto [origIdx, newIdx] : forResultMapping)
19391911 rewriter.replaceAllUsesExcept (warpOp.getResult (origIdx),
19401912 newForOp.getResult (newIdx), newForOp);
1941- // newForOp->setOperand(res.index() + 3,
1942- // newWarpOp.getResult(res.value()));
1943- }
1944-
1945- for (auto [origIdx, newIdx] : warpResultMapping) {
1913+ // Similarly, update any users of the warp op results that were not
1914+ // results of the forOp.
1915+ for (auto [origIdx, newIdx] : warpResultMapping)
19461916 rewriter.replaceAllUsesWith (warpOp.getResult (origIdx),
19471917 newWarpOp.getResult (newIdx));
1948- // newForOp->setOperand(res.index() + 3,
1949- // newWarpOp.getResult(res.value()));
1950- }
1918+ // Remove the original warp op and for op, they should not have any uses
1919+ // at this point.
19511920 rewriter.eraseOp (forOp);
19521921 rewriter.eraseOp (warpOp);
1922+ // Update any users of escaping values that were forwarded to the
1923+ // inner warp op. These values are now arguments of the inner warp op.
19531924 newForOp.walk ([&](Operation *op) {
19541925 for (OpOperand &operand : op->getOpOperands ()) {
19551926 auto it = argIndexMapping.find (operand.get ());
@@ -1958,8 +1929,6 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
19581929 operand.set (innerWarp.getBodyRegion ().getArgument (it->second ));
19591930 }
19601931 });
1961- // newForOp->getParentOp()->print(llvm::outs());
1962- // llvm::outs() << "\n";
19631932
19641933 // Finally, hoist out any now uniform code from the inner warp op.
19651934 mlir::vector::moveScalarUniformCode (innerWarp);
0 commit comments