Skip to content

Commit 5868390

Browse files
committed
Working, but there is a bug in the dead-result case
1 parent c539ec0 commit 5868390

File tree

1 file changed

+50
-16
lines changed

1 file changed

+50
-16
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp

Lines changed: 50 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@
1717
#include "mlir/IR/AffineExpr.h"
1818
#include "mlir/IR/Attributes.h"
1919
#include "mlir/IR/BuiltinTypes.h"
20+
#include "mlir/IR/Value.h"
2021
#include "mlir/Interfaces/SideEffectInterfaces.h"
2122
#include "mlir/Transforms/RegionUtils.h"
2223
#include "llvm/ADT/SetVector.h"
@@ -1777,24 +1778,42 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
17771778
if (llvm::is_contained(distTypes, Type{}))
17781779
return failure();
17791780

1781+
llvm::errs() << "escpaing values size: " << escapingValues.size() << "\n";
1782+
1783+
SmallVector<Value> yieldedValuesFromWarpOp;
1784+
// All init args of the forOp are yielded from the original warp op.
1785+
for (Value initArg : forOp.getInitArgs()) {
1786+
yieldedValuesFromWarpOp.push_back(initArg);
1787+
// find distributed type for the init arg.
1788+
Type distType = initArg.getType();
1789+
if (auto vecType = dyn_cast<VectorType>(distType)) {
1790+
AffineMap map = distributionMapFn(initArg);
1791+
distType = getDistributedType(vecType, map, warpOp.getWarpSize());
1792+
}
1793+
distTypes.push_back(distType);
1794+
}
1795+
// All escaping values are yielded from the original warp op.
1796+
yieldedValuesFromWarpOp.insert(yieldedValuesFromWarpOp.end(),
1797+
escapingValues.begin(),
1798+
escapingValues.end());
1799+
17801800
SmallVector<size_t> newRetIndices;
17811801
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1782-
rewriter, warpOp, escapingValues.getArrayRef(), distTypes,
1783-
newRetIndices);
1802+
rewriter, warpOp, yieldedValuesFromWarpOp, distTypes, newRetIndices);
17841803
yield = cast<gpu::YieldOp>(
17851804
newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());
17861805

17871806
SmallVector<Value> newOperands;
17881807
SmallVector<unsigned> resultIdx;
1789-
// Collect all the outputs coming from the forOp.
1808+
// Collect the new init args coming from the new warp op.
1809+
for (size_t i = 0; i < forOp.getInitArgs().size(); ++i)
1810+
newOperands.push_back(newWarpOp.getResult(newRetIndices[i]));
17901811
for (OpOperand &yieldOperand : yield->getOpOperands()) {
17911812
if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
17921813
continue;
1793-
auto forResult = cast<OpResult>(yieldOperand.get());
1794-
newOperands.push_back(
1795-
newWarpOp.getResult(yieldOperand.getOperandNumber()));
1814+
OpResult forResult = cast<OpResult>(yieldOperand.get());
1815+
resultIdx.push_back(forResult.getResultNumber());
17961816
yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
1797-
resultIdx.push_back(yieldOperand.getOperandNumber());
17981817
}
17991818

18001819
OpBuilder::InsertionGuard g(rewriter);
@@ -1812,8 +1831,8 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
18121831
SmallVector<Type> warpInputType(forOp.getResultTypes().begin(),
18131832
forOp.getResultTypes().end());
18141833
llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
1815-
for (auto [i, retIdx] : llvm::enumerate(newRetIndices)) {
1816-
warpInput.push_back(newWarpOp.getResult(retIdx));
1834+
for (size_t i = forOp.getInitArgs().size(); i < newRetIndices.size(); ++i) {
1835+
warpInput.push_back(newWarpOp.getResult(i));
18171836
argIndexMapping[escapingValues[i]] = warpInputType.size();
18181837
warpInputType.push_back(inputTypes[i]);
18191838
}
@@ -1826,24 +1845,37 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
18261845
for (Value args : innerWarp.getBody()->getArguments()) {
18271846
argMapping.push_back(args);
18281847
}
1829-
argMapping.resize(forOp.getBody()->getNumArguments());
1848+
auto forOpCopy = cast<scf::ForOp>(rewriter.clone(*forOp.getOperation()));
1849+
argMapping.resize(forOpCopy.getBody()->getNumArguments());
18301850
SmallVector<Value> yieldOperands;
1831-
for (Value operand : forOp.getBody()->getTerminator()->getOperands())
1851+
for (Value operand : forOpCopy.getBody()->getTerminator()->getOperands())
18321852
yieldOperands.push_back(operand);
1833-
rewriter.eraseOp(forOp.getBody()->getTerminator());
1834-
rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
1853+
1854+
rewriter.eraseOp(forOpCopy.getBody()->getTerminator());
1855+
rewriter.mergeBlocks(forOpCopy.getBody(), innerWarp.getBody(), argMapping);
18351856
rewriter.setInsertionPointToEnd(innerWarp.getBody());
18361857
rewriter.create<gpu::YieldOp>(innerWarp.getLoc(), yieldOperands);
18371858
rewriter.setInsertionPointAfter(innerWarp);
18381859
if (!innerWarp.getResults().empty())
1839-
rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());
1840-
rewriter.eraseOp(forOp);
1860+
rewriter.create<scf::YieldOp>(forOpCopy.getLoc(), innerWarp.getResults());
1861+
// forOpCopy->getParentOp()->getParentOp()->print(llvm::outs());
1862+
// llvm::outs() << "\n";
1863+
// llvm::errs() << "erasing for op\n";
1864+
1865+
rewriter.eraseOp(forOpCopy);
18411866
// Replace the warpOp result coming from the original ForOp.
1867+
// print resultIdx for debugging.
1868+
llvm::errs() << "resultIdx: ";
1869+
for (auto idx : resultIdx)
1870+
llvm::errs() << idx << " ";
1871+
llvm::errs() << "\n";
18421872
for (const auto &res : llvm::enumerate(resultIdx)) {
18431873
rewriter.replaceAllUsesWith(newWarpOp.getResult(res.value()),
18441874
newForOp.getResult(res.index()));
1845-
newForOp->setOperand(res.index() + 3, newWarpOp.getResult(res.value()));
1875+
// newForOp->setOperand(res.index() + 3,
1876+
// newWarpOp.getResult(res.value()));
18461877
}
1878+
rewriter.eraseOp(forOp);
18471879
newForOp.walk([&](Operation *op) {
18481880
for (OpOperand &operand : op->getOpOperands()) {
18491881
auto it = argIndexMapping.find(operand.get());
@@ -1852,6 +1884,8 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
18521884
operand.set(innerWarp.getBodyRegion().getArgument(it->second));
18531885
}
18541886
});
1887+
newForOp->getParentOp()->print(llvm::outs());
1888+
llvm::outs() << "\n";
18551889

18561890
// Finally, hoist out any now uniform code from the inner warp op.
18571891
mlir::vector::moveScalarUniformCode(innerWarp);

0 commit comments

Comments
 (0)