1717#include " mlir/IR/AffineExpr.h"
1818#include " mlir/IR/Attributes.h"
1919#include " mlir/IR/BuiltinTypes.h"
20+ #include " mlir/IR/Value.h"
2021#include " mlir/Interfaces/SideEffectInterfaces.h"
2122#include " mlir/Transforms/RegionUtils.h"
2223#include " llvm/ADT/SetVector.h"
@@ -1777,24 +1778,42 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
17771778 if (llvm::is_contained (distTypes, Type{}))
17781779 return failure ();
17791780
1781+ llvm::errs () << " escpaing values size: " << escapingValues.size () << " \n " ;
1782+
1783+ SmallVector<Value> yieldedValuesFromWarpOp;
1784+ // All init args of the forOp are yielded from the original warp op.
1785+ for (Value initArg : forOp.getInitArgs ()) {
1786+ yieldedValuesFromWarpOp.push_back (initArg);
1787+ // Find the distributed type for the init arg.
1788+ Type distType = initArg.getType ();
1789+ if (auto vecType = dyn_cast<VectorType>(distType)) {
1790+ AffineMap map = distributionMapFn (initArg);
1791+ distType = getDistributedType (vecType, map, warpOp.getWarpSize ());
1792+ }
1793+ distTypes.push_back (distType);
1794+ }
1795+ // All escaping values are yielded from the original warp op.
1796+ yieldedValuesFromWarpOp.insert (yieldedValuesFromWarpOp.end (),
1797+ escapingValues.begin (),
1798+ escapingValues.end ());
1799+
17801800 SmallVector<size_t > newRetIndices;
17811801 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns (
1782- rewriter, warpOp, escapingValues.getArrayRef (), distTypes,
1783- newRetIndices);
1802+ rewriter, warpOp, yieldedValuesFromWarpOp, distTypes, newRetIndices);
17841803 yield = cast<gpu::YieldOp>(
17851804 newWarpOp.getBodyRegion ().getBlocks ().begin ()->getTerminator ());
17861805
17871806 SmallVector<Value> newOperands;
17881807 SmallVector<unsigned > resultIdx;
1789- // Collect all the outputs coming from the forOp.
1808+ // Collect the new init args coming from the new warp op.
1809+ for (size_t i = 0 ; i < forOp.getInitArgs ().size (); ++i)
1810+ newOperands.push_back (newWarpOp.getResult (newRetIndices[i]));
17901811 for (OpOperand &yieldOperand : yield->getOpOperands ()) {
17911812 if (yieldOperand.get ().getDefiningOp () != forOp.getOperation ())
17921813 continue ;
1793- auto forResult = cast<OpResult>(yieldOperand.get ());
1794- newOperands.push_back (
1795- newWarpOp.getResult (yieldOperand.getOperandNumber ()));
1814+ OpResult forResult = cast<OpResult>(yieldOperand.get ());
1815+ resultIdx.push_back (forResult.getResultNumber ());
17961816 yieldOperand.set (forOp.getInitArgs ()[forResult.getResultNumber ()]);
1797- resultIdx.push_back (yieldOperand.getOperandNumber ());
17981817 }
17991818
18001819 OpBuilder::InsertionGuard g (rewriter);
@@ -1812,8 +1831,8 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
18121831 SmallVector<Type> warpInputType (forOp.getResultTypes ().begin (),
18131832 forOp.getResultTypes ().end ());
18141833 llvm::SmallDenseMap<Value, int64_t > argIndexMapping;
1815- for (auto [i, retIdx] : llvm::enumerate ( newRetIndices) ) {
1816- warpInput.push_back (newWarpOp.getResult (retIdx ));
1834+ for (size_t i = forOp. getInitArgs (). size (); i < newRetIndices. size (); ++i ) {
1835+ warpInput.push_back (newWarpOp.getResult (i ));
18171836 argIndexMapping[escapingValues[i]] = warpInputType.size ();
18181837 warpInputType.push_back (inputTypes[i]);
18191838 }
@@ -1826,24 +1845,37 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
18261845 for (Value args : innerWarp.getBody ()->getArguments ()) {
18271846 argMapping.push_back (args);
18281847 }
1829- argMapping.resize (forOp.getBody ()->getNumArguments ());
1848+ auto forOpCopy = cast<scf::ForOp>(rewriter.clone (*forOp.getOperation ()));
1849+ argMapping.resize (forOpCopy.getBody ()->getNumArguments ());
18301850 SmallVector<Value> yieldOperands;
1831- for (Value operand : forOp .getBody ()->getTerminator ()->getOperands ())
1851+ for (Value operand : forOpCopy .getBody ()->getTerminator ()->getOperands ())
18321852 yieldOperands.push_back (operand);
1833- rewriter.eraseOp (forOp.getBody ()->getTerminator ());
1834- rewriter.mergeBlocks (forOp.getBody (), innerWarp.getBody (), argMapping);
1853+
1854+ rewriter.eraseOp (forOpCopy.getBody ()->getTerminator ());
1855+ rewriter.mergeBlocks (forOpCopy.getBody (), innerWarp.getBody (), argMapping);
18351856 rewriter.setInsertionPointToEnd (innerWarp.getBody ());
18361857 rewriter.create <gpu::YieldOp>(innerWarp.getLoc (), yieldOperands);
18371858 rewriter.setInsertionPointAfter (innerWarp);
18381859 if (!innerWarp.getResults ().empty ())
1839- rewriter.create <scf::YieldOp>(forOp.getLoc (), innerWarp.getResults ());
1840- rewriter.eraseOp (forOp);
1860+ rewriter.create <scf::YieldOp>(forOpCopy.getLoc (), innerWarp.getResults ());
1861+ // forOpCopy->getParentOp()->getParentOp()->print(llvm::outs());
1862+ // llvm::outs() << "\n";
1863+ // llvm::errs() << "erasing for op\n";
1864+
1865+ rewriter.eraseOp (forOpCopy);
18411866 // Replace the warpOp result coming from the original ForOp.
1867+ // Print resultIdx for debugging.
1868+ llvm::errs () << " resultIdx: " ;
1869+ for (auto idx : resultIdx)
1870+ llvm::errs () << idx << " " ;
1871+ llvm::errs () << " \n " ;
18421872 for (const auto &res : llvm::enumerate (resultIdx)) {
18431873 rewriter.replaceAllUsesWith (newWarpOp.getResult (res.value ()),
18441874 newForOp.getResult (res.index ()));
1845- newForOp->setOperand (res.index () + 3 , newWarpOp.getResult (res.value ()));
1875+ // newForOp->setOperand(res.index() + 3,
1876+ // newWarpOp.getResult(res.value()));
18461877 }
1878+ rewriter.eraseOp (forOp);
18471879 newForOp.walk ([&](Operation *op) {
18481880 for (OpOperand &operand : op->getOpOperands ()) {
18491881 auto it = argIndexMapping.find (operand.get ());
@@ -1852,6 +1884,8 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
18521884 operand.set (innerWarp.getBodyRegion ().getArgument (it->second ));
18531885 }
18541886 });
1887+ newForOp->getParentOp ()->print (llvm::outs ());
1888+ llvm::outs () << " \n " ;
18551889
18561890 // Finally, hoist out any now uniform code from the inner warp op.
18571891 mlir::vector::moveScalarUniformCode (innerWarp);
0 commit comments