Skip to content

Commit 2c2703e

Browse files
committed
address comments
1 parent 683fad8 commit 2c2703e

File tree

1 file changed

+6
-8
lines changed

1 file changed

+6
-8
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1749,10 +1749,10 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
17491749
: WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
17501750
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
17511751
PatternRewriter &rewriter) const override {
1752-
auto newWarpOpYield = cast<gpu::YieldOp>(
1752+
auto warpOpYield = cast<gpu::YieldOp>(
17531753
warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
17541754
// Only pick up `ForOp` if it is the last op in the region.
1755-
Operation *lastNode = newWarpOpYield->getPrevNode();
1755+
Operation *lastNode = warpOpYield->getPrevNode();
17561756
auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
17571757
if (!forOp)
17581758
return failure();
@@ -1789,7 +1789,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
17891789
SmallVector<Value> nonForYieldedValues;
17901790
SmallVector<unsigned> nonForResultIndices;
17911791
llvm::SmallDenseMap<unsigned, unsigned> forResultMapping;
1792-
for (OpOperand &yieldOperand : newWarpOpYield->getOpOperands()) {
1792+
for (OpOperand &yieldOperand : warpOpYield->getOpOperands()) {
17931793
// Yielded value is not a result of the forOp.
17941794
if (yieldOperand.get().getDefiningOp() != forOp.getOperation()) {
17951795
nonForYieldedValues.push_back(yieldOperand.get());
@@ -1827,18 +1827,16 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
18271827
// types. We also create a mapping between the non-`ForOp` yielded value
18281828
// index and the corresponding new `WarpOp` yield value index (needed to
18291829
// update users later).
1830-
llvm::SmallDenseMap<unsigned, unsigned> warpResultMapping;
1830+
llvm::SmallDenseMap<unsigned, unsigned> nonForResultMapping;
18311831
for (auto [i, v] : llvm::enumerate(nonForYieldedValues)) {
1832-
warpResultMapping[nonForResultIndices[i]] = newWarpOpYieldValues.size();
1832+
nonForResultMapping[nonForResultIndices[i]] = newWarpOpYieldValues.size();
18331833
newWarpOpYieldValues.push_back(v);
18341834
newWarpOpDistTypes.push_back(
18351835
warpOp.getResult(nonForResultIndices[i]).getType());
18361836
}
18371837
// Create the new `WarpOp` with the updated yield values and types.
18381838
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
18391839
rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
1840-
newWarpOpYield = cast<gpu::YieldOp>(
1841-
newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());
18421840

18431841
// Next, we create a new `ForOp` with the init args yielded by the new
18441842
// `WarpOp`.
@@ -1912,7 +1910,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
19121910
newForOp.getResult(newIdx), newForOp);
19131911
// Similarly, update any users of the `WarpOp` results that were not
19141912
// results of the `ForOp`.
1915-
for (auto [origIdx, newIdx] : warpResultMapping)
1913+
for (auto [origIdx, newIdx] : nonForResultMapping)
19161914
rewriter.replaceAllUsesWith(warpOp.getResult(origIdx),
19171915
newWarpOp.getResult(newIdx));
19181916
// Remove the original `WarpOp` and `ForOp`, they should not have any uses

0 commit comments

Comments
 (0)