Skip to content

Commit 164e9d6

Browse files
committed
address comments
1 parent 8ecece4 commit 164e9d6

File tree

1 file changed

+44
-44
lines changed

1 file changed

+44
-44
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp

Lines changed: 44 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1751,13 +1751,13 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
17511751
PatternRewriter &rewriter) const override {
17521752
auto newWarpOpYield = cast<gpu::YieldOp>(
17531753
warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
1754-
// Only pick up forOp if it is the last op in the region.
1754+
// Only pick up `ForOp` if it is the last op in the region.
17551755
Operation *lastNode = newWarpOpYield->getPrevNode();
17561756
auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
17571757
if (!forOp)
17581758
return failure();
1759-
// Collect Values that come from the warp op but are outside the forOp.
1760-
// Those Value needs to be returned by the new warp op.
1759+
// Collect Values that come from the `WarpOp` but are outside the `ForOp`.
1760+
// Those Values need to be returned by the new warp op.
17611761
llvm::SmallSetVector<Value, 32> escapingValues;
17621762
SmallVector<Type> escapingValueInputTypes;
17631763
SmallVector<Type> escapingValuedistTypes;
@@ -1779,16 +1779,16 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
17791779

17801780
if (llvm::is_contained(escapingValuedistTypes, Type{}))
17811781
return failure();
1782-
// Warp op can yield two types of values:
1783-
// 1. Values that are not results of the forOp:
1784-
// These values must also be yielded by the new warp op. Also, we need to
1785-
// record the index mapping for these values to replace them later.
1786-
// 2. Values that are results of the forOp:
1787-
// In this case, we record the index mapping between the warp op result
1788-
// index and matching forOp result index.
1782+
// `WarpOp` can yield two types of values:
1783+
// 1. Values that are not results of the `ForOp`:
1784+
// These values must also be yielded by the new `WarpOp`. Also, we need
1785+
// to record the index mapping for these values to replace them later.
1786+
// 2. Values that are results of the `ForOp`:
1787+
// In this case, we record the index mapping between the `WarpOp` result
1788+
// index and matching `ForOp` result index.
17891789
SmallVector<Value> nonForYieldedValues;
17901790
SmallVector<unsigned> nonForResultIndices;
1791-
DenseMap<unsigned, unsigned> forResultMapping;
1791+
llvm::SmallDenseMap<unsigned, unsigned> forResultMapping;
17921792
for (OpOperand &yieldOperand : newWarpOpYield->getOpOperands()) {
17931793
// Yielded value is not a result of the forOp.
17941794
if (yieldOperand.get().getDefiningOp() != forOp.getOperation()) {
@@ -1801,10 +1801,10 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
18011801
forResult.getResultNumber();
18021802
}
18031803

1804-
// Newly created warp op will yield values in following order:
1805-
// 1. All init args of the forOp.
1804+
// Newly created `WarpOp` will yield values in following order:
1805+
// 1. All init args of the `ForOp`.
18061806
// 2. All escaping values.
1807-
// 3. All non-for yielded values.
1807+
// 3. All non-`ForOp` yielded values.
18081808
SmallVector<Value> newWarpOpYieldValues;
18091809
SmallVector<Type> newWarpOpDistTypes;
18101810
for (auto [i, initArg] : llvm::enumerate(forOp.getInitArgs())) {
@@ -1823,50 +1823,50 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
18231823
newWarpOpDistTypes.insert(newWarpOpDistTypes.end(),
18241824
escapingValuedistTypes.begin(),
18251825
escapingValuedistTypes.end());
1826-
// Next, we insert all non-for yielded values and their distributed types.
1827-
// We also create a mapping between the non-for yielded value index and the
1828-
// corresponding new warp op yield value index (needed to update users
1829-
// later).
1830-
DenseMap<unsigned, unsigned> warpResultMapping;
1826+
// Next, we insert all non-`ForOp` yielded values and their distributed
1827+
// types. We also create a mapping between the non-`ForOp` yielded value
1828+
// index and the corresponding new `WarpOp` yield value index (needed to
1829+
// update users later).
1830+
llvm::SmallDenseMap<unsigned, unsigned> warpResultMapping;
18311831
for (auto [i, v] : llvm::enumerate(nonForYieldedValues)) {
18321832
warpResultMapping[nonForResultIndices[i]] = newWarpOpYieldValues.size();
18331833
newWarpOpYieldValues.push_back(v);
18341834
newWarpOpDistTypes.push_back(
18351835
warpOp.getResult(nonForResultIndices[i]).getType());
18361836
}
1837-
// Create the new warp op with the updated yield values and types.
1837+
// Create the new `WarpOp` with the updated yield values and types.
18381838
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
18391839
rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
18401840
newWarpOpYield = cast<gpu::YieldOp>(
18411841
newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());
18421842

1843-
// Next, we create a new for op with the init args yielded by the new
1844-
// warp op.
1845-
unsigned escapingValuesStartIdx =
1846-
forOp.getInitArgs().size(); // ForOp init args are positioned before
1847-
// escaping values in the new warp op.
1843+
// Next, we create a new `ForOp` with the init args yielded by the new
1844+
// `WarpOp`.
1845+
const unsigned escapingValuesStartIdx =
1846+
forOp.getInitArgs().size(); // `ForOp` init args are positioned before
1847+
// escaping values in the new `WarpOp`.
18481848
SmallVector<Value> newForOpOperands;
18491849
for (size_t i = 0; i < escapingValuesStartIdx; ++i)
18501850
newForOpOperands.push_back(newWarpOp.getResult(i));
18511851

1852-
// Create a new for op outside the new warp op region.
1852+
// Create a new `ForOp` outside the new `WarpOp` region.
18531853
OpBuilder::InsertionGuard g(rewriter);
18541854
rewriter.setInsertionPointAfter(newWarpOp);
18551855
auto newForOp = rewriter.create<scf::ForOp>(
18561856
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
18571857
forOp.getStep(), newForOpOperands);
1858-
// Next, we insert a new warp op (called inner warp op) inside the
1859-
// newly created for op. This warp op will contain all ops that were
1860-
// contained within the original for op body.
1858+
// Next, we insert a new `WarpOp` (called inner `WarpOp`) inside the
1859+
// newly created `ForOp`. This `WarpOp` will contain all ops that were
1860+
// contained within the original `ForOp` body.
18611861
rewriter.setInsertionPointToStart(newForOp.getBody());
18621862

18631863
SmallVector<Value> innerWarpInput(newForOp.getRegionIterArgs().begin(),
18641864
newForOp.getRegionIterArgs().end());
18651865
SmallVector<Type> innerWarpInputType(forOp.getResultTypes().begin(),
18661866
forOp.getResultTypes().end());
1867-
// Escaping values are forwarded to the inner warp op as its (additional)
1867+
// Escaping values are forwarded to the inner `WarpOp` as its (additional)
18681868
// arguments. We keep track of the mapping between these values and their
1869-
// argument index in the inner warp op (to replcace uses later).
1869+
// argument index in the inner `WarpOp` (to replace users later).
18701870
llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
18711871
for (size_t i = escapingValuesStartIdx;
18721872
i < escapingValuesStartIdx + escapingValues.size(); ++i) {
@@ -1876,12 +1876,12 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
18761876
innerWarpInputType.push_back(
18771877
escapingValueInputTypes[i - escapingValuesStartIdx]);
18781878
}
1879-
// Create the inner warp op with the new input values and types.
1879+
// Create the inner `WarpOp` with the new input values and types.
18801880
auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
18811881
newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(),
18821882
newWarpOp.getWarpSize(), innerWarpInput, innerWarpInputType);
18831883

1884-
// Inline the for op body into the inner warp op body.
1884+
// Inline the `ForOp` body into the inner `WarpOp` body.
18851885
SmallVector<Value> argMapping;
18861886
argMapping.push_back(newForOp.getInductionVar());
18871887
for (Value args : innerWarp.getBody()->getArguments())
@@ -1895,32 +1895,32 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
18951895
rewriter.eraseOp(forOp.getBody()->getTerminator());
18961896
rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
18971897

1898-
// Insert a gpu yieldOp at the end of the inner warp op body that yields
1899-
// original forOp results.
1898+
// Insert a gpu `YieldOp` at the end of the inner `WarpOp` body that yields
1899+
// original `ForOp` results.
19001900
rewriter.setInsertionPointToEnd(innerWarp.getBody());
19011901
rewriter.create<gpu::YieldOp>(innerWarp.getLoc(), yieldOperands);
19021902
rewriter.setInsertionPointAfter(innerWarp);
1903-
// Insert a scf.yield op at the end of the new for op body that yields
1904-
// the inner warp op results.
1903+
// Insert a scf.yield op at the end of the new `ForOp` body that yields
1904+
// the inner `WarpOp` results.
19051905
if (!innerWarp.getResults().empty())
19061906
rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());
19071907

1908-
// Update the users of original warp op results that were coming from the
1909-
// original forOp to the corresponding new forOp result.
1908+
// Update the users of original `WarpOp` results that were coming from the
1909+
// original `ForOp` to the corresponding new `ForOp` result.
19101910
for (auto [origIdx, newIdx] : forResultMapping)
19111911
rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx),
19121912
newForOp.getResult(newIdx), newForOp);
1913-
// Similarly, update any users of the warp op results that were not
1914-
// results of the forOp.
1913+
// Similarly, update any users of the `WarpOp` results that were not
1914+
// results of the `ForOp`.
19151915
for (auto [origIdx, newIdx] : warpResultMapping)
19161916
rewriter.replaceAllUsesWith(warpOp.getResult(origIdx),
19171917
newWarpOp.getResult(newIdx));
1918-
// Remove the original warp op and for op, they should not have any uses
1918+
// Remove the original `WarpOp` and `ForOp`, they should not have any uses
19191919
// at this point.
19201920
rewriter.eraseOp(forOp);
19211921
rewriter.eraseOp(warpOp);
19221922
// Update any users of escaping values that were forwarded to the
1923-
// inner warp op. These values are now arguments of the inner warp op.
1923+
// inner `WarpOp`. These values are now arguments of the inner `WarpOp`.
19241924
newForOp.walk([&](Operation *op) {
19251925
for (OpOperand &operand : op->getOpOperands()) {
19261926
auto it = argIndexMapping.find(operand.get());
@@ -1930,7 +1930,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
19301930
}
19311931
});
19321932

1933-
// Finally, hoist out any now uniform code from the inner warp op.
1933+
// Finally, hoist out any now uniform code from the inner `WarpOp`.
19341934
mlir::vector::moveScalarUniformCode(innerWarp);
19351935
return success();
19361936
}

0 commit comments

Comments
 (0)