Skip to content

Commit 3595f17

Browse files
committed
working version refined
1 parent 4c36317 commit 3595f17

File tree

1 file changed

+52
-23
lines changed

1 file changed

+52
-23
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp

Lines changed: 52 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@
2020
#include "mlir/IR/Value.h"
2121
#include "mlir/Interfaces/SideEffectInterfaces.h"
2222
#include "mlir/Transforms/RegionUtils.h"
23+
#include "llvm/ADT/DenseMap.h"
2324
#include "llvm/ADT/SetVector.h"
2425
#include "llvm/ADT/SmallVectorExtras.h"
2526
#include "llvm/Support/FormatVariadic.h"
@@ -1778,37 +1779,53 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
17781779
if (llvm::is_contained(distTypes, Type{}))
17791780
return failure();
17801781

1781-
llvm::errs() << "escpaing values size: " << escapingValues.size() << "\n";
1782+
// record result mapping.
1783+
SmallVector<unsigned> resultIdx;
1784+
llvm::SmallDenseMap<unsigned, unsigned> forResultToWarpResultMapping;
1785+
for (OpOperand &yieldOperand : yield->getOpOperands()) {
1786+
if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
1787+
continue;
1788+
OpResult forResult = cast<OpResult>(yieldOperand.get());
1789+
resultIdx.push_back(forResult.getResultNumber());
1790+
forResultToWarpResultMapping[forResult.getResultNumber()] =
1791+
yieldOperand.getOperandNumber();
1792+
// yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
1793+
}
1794+
1795+
// llvm::errs() << "escpaing values size: " << escapingValues.size() <<
1796+
// "\n";
17821797

17831798
SmallVector<Value> yieldedValuesFromWarpOp;
1799+
SmallVector<Type> yieldedTypesFromWarpOp;
17841800
// All init args of the forOp are yielded from the original warp op.
1785-
for (Value initArg : forOp.getInitArgs()) {
1801+
for (auto [i, initArg] : llvm::enumerate(forOp.getInitArgs())) {
17861802
yieldedValuesFromWarpOp.push_back(initArg);
17871803
// find distributed type for the init arg.
17881804
Type distType = initArg.getType();
17891805
if (auto vecType = dyn_cast<VectorType>(distType)) {
1790-
AffineMap map = distributionMapFn(initArg);
1791-
distType = getDistributedType(vecType, map, warpOp.getWarpSize());
1806+
if (forResultToWarpResultMapping.contains(i)) {
1807+
// If the matching loop result is yielded from the warp op, reuse the
1808+
// already-distributed type of that warp op result.
1809+
distType =
1810+
warpOp.getResult(forResultToWarpResultMapping[i]).getType();
1811+
} else {
1812+
AffineMap map = distributionMapFn(initArg);
1813+
distType = getDistributedType(vecType, map, warpOp.getWarpSize());
1814+
}
17921815
}
1793-
distTypes.push_back(distType);
1816+
// llvm::errs() << "distributed type: " << distType << "\n";
1817+
yieldedTypesFromWarpOp.push_back(distType);
17941818
}
17951819
// All escaping values are yielded from the original warp op.
17961820
yieldedValuesFromWarpOp.insert(yieldedValuesFromWarpOp.end(),
17971821
escapingValues.begin(),
17981822
escapingValues.end());
1799-
// record result mapping.
1800-
SmallVector<unsigned> resultIdx;
1801-
for (OpOperand &yieldOperand : yield->getOpOperands()) {
1802-
if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
1803-
continue;
1804-
OpResult forResult = cast<OpResult>(yieldOperand.get());
1805-
resultIdx.push_back(forResult.getResultNumber());
1806-
// yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
1807-
}
1823+
yieldedTypesFromWarpOp.insert(yieldedTypesFromWarpOp.end(),
1824+
distTypes.begin(), distTypes.end());
18081825

18091826
// SmallVector<size_t> newRetIndices;
18101827
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
1811-
rewriter, warpOp, yieldedValuesFromWarpOp, distTypes);
1828+
rewriter, warpOp, yieldedValuesFromWarpOp, yieldedTypesFromWarpOp);
18121829
yield = cast<gpu::YieldOp>(
18131830
newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());
18141831

@@ -1839,15 +1856,27 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
18391856
SmallVector<Type> warpInputType(forOp.getResultTypes().begin(),
18401857
forOp.getResultTypes().end());
18411858
llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
1859+
// llvm::errs() << "setting arg index mapping\n";
18421860
for (size_t i = forOp.getInitArgs().size(); i < newWarpOp->getNumResults();
18431861
++i) {
18441862
warpInput.push_back(newWarpOp.getResult(i));
1845-
argIndexMapping[escapingValues[i]] = warpInputType.size();
1846-
warpInputType.push_back(inputTypes[i]);
1863+
argIndexMapping[escapingValues[i - forOp.getInitArgs().size()]] =
1864+
warpInputType.size();
1865+
warpInputType.push_back(inputTypes[i - forOp.getInitArgs().size()]);
18471866
}
1867+
// for (auto [i, r] : llvm::enumerate(
1868+
// newWarpOp.getResults().drop_front(forOp.getInitArgs().size())))
1869+
// {
1870+
// warpInput.push_back(r);
1871+
// argIndexMapping[escapingValues[i]] = warpInputType.size();
1872+
// warpInputType.push_back(inputTypes[i]);
1873+
// }
1874+
// llvm::errs() << "go here\n";
18481875
auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
18491876
newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(),
18501877
newWarpOp.getWarpSize(), warpInput, warpInputType);
1878+
// newForOp->getParentOp()->print(llvm::outs());
1879+
// llvm::outs() << "\n";
18511880

18521881
SmallVector<Value> argMapping;
18531882
argMapping.push_back(newForOp.getInductionVar());
@@ -1874,10 +1903,10 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
18741903
rewriter.eraseOp(forOpCopy);
18751904
// Replace the warpOp result coming from the original ForOp.
18761905
// print resultIdx for debugging.
1877-
llvm::errs() << "resultIdx: ";
1878-
for (auto idx : resultIdx)
1879-
llvm::errs() << idx << " ";
1880-
llvm::errs() << "\n";
1906+
// llvm::errs() << "resultIdx: ";
1907+
// for (auto idx : resultIdx)
1908+
// llvm::errs() << idx << " ";
1909+
// llvm::errs() << "\n";
18811910
for (const auto &res : llvm::enumerate(resultIdx)) {
18821911
rewriter.replaceAllUsesExcept(warpOp.getResult(res.value()),
18831912
newForOp.getResult(res.index()), newForOp);
@@ -1894,8 +1923,8 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
18941923
operand.set(innerWarp.getBodyRegion().getArgument(it->second));
18951924
}
18961925
});
1897-
newForOp->getParentOp()->print(llvm::outs());
1898-
llvm::outs() << "\n";
1926+
// newForOp->getParentOp()->print(llvm::outs());
1927+
// llvm::outs() << "\n";
18991928

19001929
// Finally, hoist out any now uniform code from the inner warp op.
19011930
mlir::vector::moveScalarUniformCode(innerWarp);

0 commit comments

Comments
 (0)