Skip to content

Commit 28ef9c9

Browse files
committed
add comments and tests
1 parent ba94ee2 commit 28ef9c9

File tree

2 files changed

+172
-124
lines changed

2 files changed

+172
-124
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp

Lines changed: 93 additions & 124 deletions
Original file line numberDiff line numberDiff line change
@@ -1749,19 +1749,18 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
17491749
: WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
17501750
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
17511751
PatternRewriter &rewriter) const override {
1752-
auto yield = cast<gpu::YieldOp>(
1752+
auto newWarpOpYield = cast<gpu::YieldOp>(
17531753
warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
17541754
// Only pick up forOp if it is the last op in the region.
1755-
Operation *lastNode = yield->getPrevNode();
1755+
Operation *lastNode = newWarpOpYield->getPrevNode();
17561756
auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
17571757
if (!forOp)
17581758
return failure();
17591759
// Collect Values that come from the warp op but are outside the forOp.
1760-
// Those Value needs to be returned by the original warpOp and passed to
1761-
// the new op.
1760+
// Those Values need to be returned by the new warp op.
17621761
llvm::SmallSetVector<Value, 32> escapingValues;
1763-
SmallVector<Type> inputTypes;
1764-
SmallVector<Type> distTypes;
1762+
SmallVector<Type> escapingValueInputTypes;
1763+
SmallVector<Type> escapingValuedistTypes;
17651764
mlir::visitUsedValuesDefinedAbove(
17661765
forOp.getBodyRegion(), [&](OpOperand *operand) {
17671766
Operation *parent = operand->get().getParentRegion()->getParentOp();
@@ -1773,183 +1772,155 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
17731772
AffineMap map = distributionMapFn(operand->get());
17741773
distType = getDistributedType(vecType, map, warpOp.getWarpSize());
17751774
}
1776-
inputTypes.push_back(operand->get().getType());
1777-
distTypes.push_back(distType);
1775+
escapingValueInputTypes.push_back(operand->get().getType());
1776+
escapingValuedistTypes.push_back(distType);
17781777
}
17791778
});
17801779

1781-
if (llvm::is_contained(distTypes, Type{}))
1780+
if (llvm::is_contained(escapingValuedistTypes, Type{}))
17821781
return failure();
1783-
1782+
// Warp op can yield two types of values:
1783+
// 1. Values that are not results of the forOp:
1784+
// These values must also be yielded by the new warp op. Also, we need to
1785+
// record the index mapping for these values to replace them later.
1786+
// 2. Values that are results of the forOp:
1787+
// In this case, we record the index mapping between the warp op result
1788+
// index and matching forOp result index.
17841789
SmallVector<Value> nonForYieldedValues;
1785-
// SmallVector<Type> nonForYieldedTypes;
17861790
SmallVector<unsigned> nonForResultIndices;
1787-
1788-
// record result mapping.
17891791
DenseMap<unsigned, unsigned> forResultMapping;
1790-
DenseMap<unsigned, unsigned> warpResultMapping;
1791-
// llvm::SmallDenseMap<unsigned, unsigned> forResultToWarpResultMapping;
1792-
for (OpOperand &yieldOperand : yield->getOpOperands()) {
1792+
for (OpOperand &yieldOperand : newWarpOpYield->getOpOperands()) {
1793+
// Yielded value is not a result of the forOp.
17931794
if (yieldOperand.get().getDefiningOp() != forOp.getOperation()) {
17941795
nonForYieldedValues.push_back(yieldOperand.get());
1795-
// nonForYieldedTypes.push_back(
1796-
// warpOp.getResult(yieldOperand.getOperandNumber()).getType());
17971796
nonForResultIndices.push_back(yieldOperand.getOperandNumber());
17981797
continue;
17991798
}
18001799
OpResult forResult = cast<OpResult>(yieldOperand.get());
18011800
forResultMapping[yieldOperand.getOperandNumber()] =
18021801
forResult.getResultNumber();
1803-
// forResultToWarpResultMapping[forResult.getResultNumber()] =
1804-
// yieldOperand.getOperandNumber();
1805-
// yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
18061802
}
18071803

1808-
// llvm::errs() << "non for yielded values size: "
1809-
// << nonForYieldedValues.size() << "\n";
1810-
1811-
// llvm::errs() << "escpaing values size: " << escapingValues.size() <<
1812-
// "\n";
1813-
SmallVector<Value> yieldedValuesFromWarpOp;
1814-
SmallVector<Type> yieldedTypesFromWarpOp;
1815-
// All init args of the forOp are yielded from the original warp op.
1804+
// The newly created warp op will yield values in the following order:
1805+
// 1. All init args of the forOp.
1806+
// 2. All escaping values.
1807+
// 3. All non-for yielded values.
1808+
SmallVector<Value> newWarpOpYieldValues;
1809+
SmallVector<Type> newWarpOpDistTypes;
18161810
for (auto [i, initArg] : llvm::enumerate(forOp.getInitArgs())) {
1817-
yieldedValuesFromWarpOp.push_back(initArg);
1818-
// find distributed type for the init arg.
1811+
newWarpOpYieldValues.push_back(initArg);
1812+
// Compute the distributed type for this init arg.
18191813
Type distType = initArg.getType();
18201814
if (auto vecType = dyn_cast<VectorType>(distType)) {
1821-
// if (forResultToWarpResultMapping.contains(i)) {
1822-
// // If the init arg is yielded from the warp op, we need to compute
1823-
// the
1824-
// // distributed type.
1825-
// distType =
1826-
// warpOp.getResult(forResultToWarpResultMapping[i]).getType();
1827-
// } else {
18281815
AffineMap map = distributionMapFn(initArg);
18291816
distType = getDistributedType(vecType, map, warpOp.getWarpSize());
1830-
// }
18311817
}
1832-
// llvm::errs() << "distributed type: " << distType << "\n";
1833-
yieldedTypesFromWarpOp.push_back(distType);
1818+
newWarpOpDistTypes.push_back(distType);
18341819
}
1835-
// All escaping values are yielded from the original warp op.
1836-
yieldedValuesFromWarpOp.insert(yieldedValuesFromWarpOp.end(),
1837-
escapingValues.begin(),
1838-
escapingValues.end());
1839-
yieldedTypesFromWarpOp.insert(yieldedTypesFromWarpOp.end(),
1840-
distTypes.begin(), distTypes.end());
1841-
1820+
// Insert escaping values and their distributed types.
1821+
newWarpOpYieldValues.insert(newWarpOpYieldValues.end(),
1822+
escapingValues.begin(), escapingValues.end());
1823+
newWarpOpDistTypes.insert(newWarpOpDistTypes.end(),
1824+
escapingValuedistTypes.begin(),
1825+
escapingValuedistTypes.end());
1826+
// Next, we insert all non-for yielded values and their distributed types.
1827+
// We also create a mapping between the non-for yielded value index and the
1828+
// corresponding new warp op yield value index (needed to update users
1829+
// later).
1830+
DenseMap<unsigned, unsigned> warpResultMapping;
18421831
for (auto [i, v] : llvm::enumerate(nonForYieldedValues)) {
1843-
warpResultMapping[nonForResultIndices[i]] =
1844-
yieldedValuesFromWarpOp.size();
1845-
yieldedValuesFromWarpOp.push_back(v);
1846-
yieldedTypesFromWarpOp.push_back(
1832+
warpResultMapping[nonForResultIndices[i]] = newWarpOpYieldValues.size();
1833+
newWarpOpYieldValues.push_back(v);
1834+
newWarpOpDistTypes.push_back(
18471835
warpOp.getResult(nonForResultIndices[i]).getType());
18481836
}
1849-
1850-
// SmallVector<size_t> newRetIndices;
1837+
// Create the new warp op with the updated yield values and types.
18511838
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
1852-
rewriter, warpOp, yieldedValuesFromWarpOp, yieldedTypesFromWarpOp);
1853-
yield = cast<gpu::YieldOp>(
1839+
rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
1840+
newWarpOpYield = cast<gpu::YieldOp>(
18541841
newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());
18551842

1856-
// newWarpOp->print(llvm::outs());
1857-
// llvm::outs() << "\n";
1858-
1859-
SmallVector<Value> newOperands;
1860-
// Collect the new init args coming from the new warp op.
1861-
for (size_t i = 0; i < forOp.getInitArgs().size(); ++i)
1862-
newOperands.push_back(newWarpOp.getResult(i));
1863-
// for (OpOperand &yieldOperand : yield->getOpOperands()) {
1864-
// if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
1865-
// continue;
1866-
// OpResult forResult = cast<OpResult>(yieldOperand.get());
1867-
// resultIdx.push_back(forResult.getResultNumber());
1868-
// yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
1869-
// }
1843+
// Next, we create a new for op with the init args yielded by the new
1844+
// warp op.
1845+
unsigned escapingValuesStartIdx =
1846+
forOp.getInitArgs().size(); // ForOp init args are positioned before
1847+
// escaping values in the new warp op.
1848+
SmallVector<Value> newForOpOperands;
1849+
for (size_t i = 0; i < escapingValuesStartIdx; ++i)
1850+
newForOpOperands.push_back(newWarpOp.getResult(i));
18701851

1852+
// Create a new for op outside the new warp op region.
18711853
OpBuilder::InsertionGuard g(rewriter);
18721854
rewriter.setInsertionPointAfter(newWarpOp);
1873-
1874-
// Create a new for op outside the region with a WarpExecuteOnLane0Op
1875-
// region inside.
18761855
auto newForOp = rewriter.create<scf::ForOp>(
18771856
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
1878-
forOp.getStep(), newOperands);
1857+
forOp.getStep(), newForOpOperands);
1858+
// Next, we insert a new warp op (called inner warp op) inside the
1859+
// newly created for op. This warp op will contain all ops that were
1860+
// contained within the original for op body.
18791861
rewriter.setInsertionPointToStart(newForOp.getBody());
18801862

1881-
SmallVector<Value> warpInput(newForOp.getRegionIterArgs().begin(),
1882-
newForOp.getRegionIterArgs().end());
1883-
SmallVector<Type> warpInputType(forOp.getResultTypes().begin(),
1884-
forOp.getResultTypes().end());
1863+
SmallVector<Value> innerWarpInput(newForOp.getRegionIterArgs().begin(),
1864+
newForOp.getRegionIterArgs().end());
1865+
SmallVector<Type> innerWarpInputType(forOp.getResultTypes().begin(),
1866+
forOp.getResultTypes().end());
1867+
// Escaping values are forwarded to the inner warp op as its (additional)
1868+
// arguments. We keep track of the mapping between these values and their
1869+
// argument index in the inner warp op (to replace uses later).
18851870
llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
1886-
// llvm::errs() << "setting arg index mapping\n";
1887-
unsigned escapingValuesStartIdx = forOp.getInitArgs().size();
18881871
for (size_t i = escapingValuesStartIdx;
18891872
i < escapingValuesStartIdx + escapingValues.size(); ++i) {
1890-
warpInput.push_back(newWarpOp.getResult(i));
1873+
innerWarpInput.push_back(newWarpOp.getResult(i));
18911874
argIndexMapping[escapingValues[i - escapingValuesStartIdx]] =
1892-
warpInputType.size();
1893-
warpInputType.push_back(inputTypes[i - escapingValuesStartIdx]);
1875+
innerWarpInputType.size();
1876+
innerWarpInputType.push_back(
1877+
escapingValueInputTypes[i - escapingValuesStartIdx]);
18941878
}
1895-
// for (auto [i, r] : llvm::enumerate(
1896-
// newWarpOp.getResults().drop_front(forOp.getInitArgs().size())))
1897-
// {
1898-
// warpInput.push_back(r);
1899-
// argIndexMapping[escapingValues[i]] = warpInputType.size();
1900-
// warpInputType.push_back(inputTypes[i]);
1901-
// }
1902-
// llvm::errs() << "go here\n";
1879+
// Create the inner warp op with the new input values and types.
19031880
auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
19041881
newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(),
1905-
newWarpOp.getWarpSize(), warpInput, warpInputType);
1906-
// newForOp->getParentOp()->print(llvm::outs());
1907-
// llvm::outs() << "\n";
1882+
newWarpOp.getWarpSize(), innerWarpInput, innerWarpInputType);
19081883

1884+
// Inline the for op body into the inner warp op body.
19091885
SmallVector<Value> argMapping;
19101886
argMapping.push_back(newForOp.getInductionVar());
1911-
for (Value args : innerWarp.getBody()->getArguments()) {
1887+
for (Value args : innerWarp.getBody()->getArguments())
19121888
argMapping.push_back(args);
1913-
}
1914-
auto forOpCopy = cast<scf::ForOp>(rewriter.clone(*forOp.getOperation()));
1915-
argMapping.resize(forOpCopy.getBody()->getNumArguments());
1889+
1890+
argMapping.resize(forOp.getBody()->getNumArguments());
19161891
SmallVector<Value> yieldOperands;
1917-
for (Value operand : forOpCopy.getBody()->getTerminator()->getOperands())
1892+
for (Value operand : forOp.getBody()->getTerminator()->getOperands())
19181893
yieldOperands.push_back(operand);
19191894

1920-
rewriter.eraseOp(forOpCopy.getBody()->getTerminator());
1921-
rewriter.mergeBlocks(forOpCopy.getBody(), innerWarp.getBody(), argMapping);
1895+
rewriter.eraseOp(forOp.getBody()->getTerminator());
1896+
rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
1897+
1898+
// Insert a gpu yieldOp at the end of the inner warp op body that yields
1899+
// original forOp results.
19221900
rewriter.setInsertionPointToEnd(innerWarp.getBody());
19231901
rewriter.create<gpu::YieldOp>(innerWarp.getLoc(), yieldOperands);
19241902
rewriter.setInsertionPointAfter(innerWarp);
1903+
// Insert a scf.yield op at the end of the new for op body that yields
1904+
// the inner warp op results.
19251905
if (!innerWarp.getResults().empty())
1926-
rewriter.create<scf::YieldOp>(forOpCopy.getLoc(), innerWarp.getResults());
1927-
// forOpCopy->getParentOp()->getParentOp()->print(llvm::outs());
1928-
// llvm::outs() << "\n";
1929-
// llvm::errs() << "erasing for op\n";
1930-
1931-
rewriter.eraseOp(forOpCopy);
1932-
// Replace the warpOp result coming from the original ForOp.
1933-
// print resultIdx for debugging.
1934-
// llvm::errs() << "resultIdx: ";
1935-
// for (auto idx : resultIdx)
1936-
// llvm::errs() << idx << " ";
1937-
// llvm::errs() << "\n";
1938-
for (auto [origIdx, newIdx] : forResultMapping) {
1906+
rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());
1907+
1908+
// Update the users of original warp op results that were coming from the
1909+
// original forOp to the corresponding new forOp result.
1910+
for (auto [origIdx, newIdx] : forResultMapping)
19391911
rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx),
19401912
newForOp.getResult(newIdx), newForOp);
1941-
// newForOp->setOperand(res.index() + 3,
1942-
// newWarpOp.getResult(res.value()));
1943-
}
1944-
1945-
for (auto [origIdx, newIdx] : warpResultMapping) {
1913+
// Similarly, update any users of the warp op results that were not
1914+
// results of the forOp.
1915+
for (auto [origIdx, newIdx] : warpResultMapping)
19461916
rewriter.replaceAllUsesWith(warpOp.getResult(origIdx),
19471917
newWarpOp.getResult(newIdx));
1948-
// newForOp->setOperand(res.index() + 3,
1949-
// newWarpOp.getResult(res.value()));
1950-
}
1918+
// Remove the original warp op and for op; they should not have any uses
1919+
// at this point.
19511920
rewriter.eraseOp(forOp);
19521921
rewriter.eraseOp(warpOp);
1922+
// Update any users of escaping values that were forwarded to the
1923+
// inner warp op. These values are now arguments of the inner warp op.
19531924
newForOp.walk([&](Operation *op) {
19541925
for (OpOperand &operand : op->getOpOperands()) {
19551926
auto it = argIndexMapping.find(operand.get());
@@ -1958,8 +1929,6 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
19581929
operand.set(innerWarp.getBodyRegion().getArgument(it->second));
19591930
}
19601931
});
1961-
// newForOp->getParentOp()->print(llvm::outs());
1962-
// llvm::outs() << "\n";
19631932

19641933
// Finally, hoist out any now uniform code from the inner warp op.
19651934
mlir::vector::moveScalarUniformCode(innerWarp);

0 commit comments

Comments
 (0)