Skip to content

Commit 4c36317

Browse files
committed
working version
1 parent 5868390 commit 4c36317

File tree

1 file changed

+25
-15
lines changed

1 file changed

+25
-15
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1796,26 +1796,34 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
17961796
yieldedValuesFromWarpOp.insert(yieldedValuesFromWarpOp.end(),
17971797
escapingValues.begin(),
17981798
escapingValues.end());
1799-
1800-
SmallVector<size_t> newRetIndices;
1801-
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1802-
rewriter, warpOp, yieldedValuesFromWarpOp, distTypes, newRetIndices);
1803-
yield = cast<gpu::YieldOp>(
1804-
newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());
1805-
1806-
SmallVector<Value> newOperands;
1799+
// record result mapping.
18071800
SmallVector<unsigned> resultIdx;
1808-
// Collect the new init args coming from the new warp op.
1809-
for (size_t i = 0; i < forOp.getInitArgs().size(); ++i)
1810-
newOperands.push_back(newWarpOp.getResult(newRetIndices[i]));
18111801
for (OpOperand &yieldOperand : yield->getOpOperands()) {
18121802
if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
18131803
continue;
18141804
OpResult forResult = cast<OpResult>(yieldOperand.get());
18151805
resultIdx.push_back(forResult.getResultNumber());
1816-
yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
1806+
// yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
18171807
}
18181808

1809+
// SmallVector<size_t> newRetIndices;
1810+
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
1811+
rewriter, warpOp, yieldedValuesFromWarpOp, distTypes);
1812+
yield = cast<gpu::YieldOp>(
1813+
newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());
1814+
1815+
SmallVector<Value> newOperands;
1816+
// Collect the new init args coming from the new warp op.
1817+
for (size_t i = 0; i < forOp.getInitArgs().size(); ++i)
1818+
newOperands.push_back(newWarpOp.getResult(i));
1819+
// for (OpOperand &yieldOperand : yield->getOpOperands()) {
1820+
// if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
1821+
// continue;
1822+
// OpResult forResult = cast<OpResult>(yieldOperand.get());
1823+
// resultIdx.push_back(forResult.getResultNumber());
1824+
// yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
1825+
// }
1826+
18191827
OpBuilder::InsertionGuard g(rewriter);
18201828
rewriter.setInsertionPointAfter(newWarpOp);
18211829

@@ -1831,7 +1839,8 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
18311839
SmallVector<Type> warpInputType(forOp.getResultTypes().begin(),
18321840
forOp.getResultTypes().end());
18331841
llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
1834-
for (size_t i = forOp.getInitArgs().size(); i < newRetIndices.size(); ++i) {
1842+
for (size_t i = forOp.getInitArgs().size(); i < newWarpOp->getNumResults();
1843+
++i) {
18351844
warpInput.push_back(newWarpOp.getResult(i));
18361845
argIndexMapping[escapingValues[i]] = warpInputType.size();
18371846
warpInputType.push_back(inputTypes[i]);
@@ -1870,12 +1879,13 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
18701879
llvm::errs() << idx << " ";
18711880
llvm::errs() << "\n";
18721881
for (const auto &res : llvm::enumerate(resultIdx)) {
1873-
rewriter.replaceAllUsesWith(newWarpOp.getResult(res.value()),
1874-
newForOp.getResult(res.index()));
1882+
rewriter.replaceAllUsesExcept(warpOp.getResult(res.value()),
1883+
newForOp.getResult(res.index()), newForOp);
18751884
// newForOp->setOperand(res.index() + 3,
18761885
// newWarpOp.getResult(res.value()));
18771886
}
18781887
rewriter.eraseOp(forOp);
1888+
rewriter.eraseOp(warpOp);
18791889
newForOp.walk([&](Operation *op) {
18801890
for (OpOperand &operand : op->getOpOperands()) {
18811891
auto it = argIndexMapping.find(operand.get());

0 commit comments

Comments
 (0)