2020#include " mlir/IR/Value.h"
2121#include " mlir/Interfaces/SideEffectInterfaces.h"
2222#include " mlir/Transforms/RegionUtils.h"
23+ #include " llvm/ADT/DenseMap.h"
2324#include " llvm/ADT/SetVector.h"
2425#include " llvm/ADT/SmallVectorExtras.h"
2526#include " llvm/Support/FormatVariadic.h"
@@ -1778,37 +1779,53 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
17781779 if (llvm::is_contained (distTypes, Type{}))
17791780 return failure ();
17801781
1781- llvm::errs () << " escpaing values size: " << escapingValues.size () << " \n " ;
1782+ // record result mapping.
1783+ SmallVector<unsigned > resultIdx;
1784+ llvm::SmallDenseMap<unsigned , unsigned > forResultToWarpResultMapping;
1785+ for (OpOperand &yieldOperand : yield->getOpOperands ()) {
1786+ if (yieldOperand.get ().getDefiningOp () != forOp.getOperation ())
1787+ continue ;
1788+ OpResult forResult = cast<OpResult>(yieldOperand.get ());
1789+ resultIdx.push_back (forResult.getResultNumber ());
1790+ forResultToWarpResultMapping[forResult.getResultNumber ()] =
1791+ yieldOperand.getOperandNumber ();
1792+ // yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
1793+ }
1794+
1795+ // llvm::errs() << "escpaing values size: " << escapingValues.size() <<
1796+ // "\n";
17821797
17831798 SmallVector<Value> yieldedValuesFromWarpOp;
1799+ SmallVector<Type> yieldedTypesFromWarpOp;
17841800 // All init args of the forOp are yielded from the original warp op.
1785- for (Value initArg : forOp.getInitArgs ()) {
1801+ for (auto [i, initArg] : llvm::enumerate ( forOp.getInitArgs () )) {
17861802 yieldedValuesFromWarpOp.push_back (initArg);
17871803 // find distributed type for the init arg.
17881804 Type distType = initArg.getType ();
17891805 if (auto vecType = dyn_cast<VectorType>(distType)) {
1790- AffineMap map = distributionMapFn (initArg);
1791- distType = getDistributedType (vecType, map, warpOp.getWarpSize ());
1806+ if (forResultToWarpResultMapping.contains (i)) {
1807+ // If the for result tied to this init arg is yielded from the warp op,
1808+ // reuse the type of the corresponding warp op result.
1809+ distType =
1810+ warpOp.getResult (forResultToWarpResultMapping[i]).getType ();
1811+ } else {
1812+ AffineMap map = distributionMapFn (initArg);
1813+ distType = getDistributedType (vecType, map, warpOp.getWarpSize ());
1814+ }
17921815 }
1793- distTypes.push_back (distType);
1816+ // llvm::errs() << "distributed type: " << distType << "\n";
1817+ yieldedTypesFromWarpOp.push_back (distType);
17941818 }
17951819 // All escaping values are yielded from the original warp op.
17961820 yieldedValuesFromWarpOp.insert (yieldedValuesFromWarpOp.end (),
17971821 escapingValues.begin (),
17981822 escapingValues.end ());
1799- // record result mapping.
1800- SmallVector<unsigned > resultIdx;
1801- for (OpOperand &yieldOperand : yield->getOpOperands ()) {
1802- if (yieldOperand.get ().getDefiningOp () != forOp.getOperation ())
1803- continue ;
1804- OpResult forResult = cast<OpResult>(yieldOperand.get ());
1805- resultIdx.push_back (forResult.getResultNumber ());
1806- // yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
1807- }
1823+ yieldedTypesFromWarpOp.insert (yieldedTypesFromWarpOp.end (),
1824+ distTypes.begin (), distTypes.end ());
18081825
18091826 // SmallVector<size_t> newRetIndices;
18101827 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns (
1811- rewriter, warpOp, yieldedValuesFromWarpOp, distTypes );
1828+ rewriter, warpOp, yieldedValuesFromWarpOp, yieldedTypesFromWarpOp );
18121829 yield = cast<gpu::YieldOp>(
18131830 newWarpOp.getBodyRegion ().getBlocks ().begin ()->getTerminator ());
18141831
@@ -1839,15 +1856,27 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
18391856 SmallVector<Type> warpInputType (forOp.getResultTypes ().begin (),
18401857 forOp.getResultTypes ().end ());
18411858 llvm::SmallDenseMap<Value, int64_t > argIndexMapping;
1859+ // llvm::errs() << "setting arg index mapping\n";
18421860 for (size_t i = forOp.getInitArgs ().size (); i < newWarpOp->getNumResults ();
18431861 ++i) {
18441862 warpInput.push_back (newWarpOp.getResult (i));
1845- argIndexMapping[escapingValues[i]] = warpInputType.size ();
1846- warpInputType.push_back (inputTypes[i]);
1863+ argIndexMapping[escapingValues[i - forOp.getInitArgs ().size ()]] =
1864+ warpInputType.size ();
1865+ warpInputType.push_back (inputTypes[i - forOp.getInitArgs ().size ()]);
18471866 }
1867+ // for (auto [i, r] : llvm::enumerate(
1868+ // newWarpOp.getResults().drop_front(forOp.getInitArgs().size())))
1869+ // {
1870+ // warpInput.push_back(r);
1871+ // argIndexMapping[escapingValues[i]] = warpInputType.size();
1872+ // warpInputType.push_back(inputTypes[i]);
1873+ // }
1874+ // llvm::errs() << "go here\n";
18481875 auto innerWarp = rewriter.create <WarpExecuteOnLane0Op>(
18491876 newWarpOp.getLoc (), newForOp.getResultTypes (), newWarpOp.getLaneid (),
18501877 newWarpOp.getWarpSize (), warpInput, warpInputType);
1878+ // newForOp->getParentOp()->print(llvm::outs());
1879+ // llvm::outs() << "\n";
18511880
18521881 SmallVector<Value> argMapping;
18531882 argMapping.push_back (newForOp.getInductionVar ());
@@ -1874,10 +1903,10 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
18741903 rewriter.eraseOp (forOpCopy);
18751904 // Replace the warpOp result coming from the original ForOp.
18761905 // print resultIdx for debugging.
1877- llvm::errs () << " resultIdx: " ;
1878- for (auto idx : resultIdx)
1879- llvm::errs () << idx << " " ;
1880- llvm::errs () << " \n " ;
1906+ // llvm::errs() << "resultIdx: ";
1907+ // for (auto idx : resultIdx)
1908+ // llvm::errs() << idx << " ";
1909+ // llvm::errs() << "\n";
18811910 for (const auto &res : llvm::enumerate (resultIdx)) {
18821911 rewriter.replaceAllUsesExcept (warpOp.getResult (res.value ()),
18831912 newForOp.getResult (res.index ()), newForOp);
@@ -1894,8 +1923,8 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
18941923 operand.set (innerWarp.getBodyRegion ().getArgument (it->second ));
18951924 }
18961925 });
1897- newForOp->getParentOp ()->print (llvm::outs ());
1898- llvm::outs () << " \n " ;
1926+ // newForOp->getParentOp()->print(llvm::outs());
1927+ // llvm::outs() << "\n";
18991928
19001929 // Finally, hoist out any now uniform code from the inner warp op.
19011930 mlir::vector::moveScalarUniformCode (innerWarp);
0 commit comments