Skip to content

Commit e57e8c3

Browse files
committed
address more comments
1 parent e733600 commit e57e8c3

File tree

2 files changed

+16
-19
lines changed

2 files changed

+16
-19
lines changed

mlir/include/mlir/Dialect/GPU/Utils/DistributionUtils.h

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===- VectorDistributionUtils.h - Distribution Utilities -------*- C++ -*-===//
1+
//===- DistributionUtils.h - Distribution Utilities -------------*- C++ -*-===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
@@ -15,13 +15,10 @@
1515
#include "mlir/IR/PatternMatch.h"
1616
#include "mlir/IR/Value.h"
1717

18-
#include <numeric>
19-
#include <utility>
20-
21-
namespace mlir {
22-
namespace gpu {
18+
namespace mlir::gpu {
2319
struct WarpDistributionPattern : OpRewritePattern<WarpExecuteOnLane0Op> {
24-
using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
20+
using OpRewritePattern::OpRewritePattern;
21+
using Base = WarpDistributionPattern;
2522

2623
virtual LogicalResult
2724
matchAndRewrite(WarpExecuteOnLane0Op op,
@@ -30,8 +27,9 @@ struct WarpDistributionPattern : OpRewritePattern<WarpExecuteOnLane0Op> {
3027
protected:
3128
/// Return a value yielded by `warpOp` which statifies the filter lamdba
3229
/// condition and is not dead.
33-
OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp,
34-
const std::function<bool(Operation *)> &fn) const;
30+
OpOperand *
31+
getWarpResult(WarpExecuteOnLane0Op warpOp,
32+
const llvm::function_ref<bool(Operation *)> fn) const;
3533

3634
/// Helper to create a new WarpExecuteOnLane0Op with different signature.
3735
WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
@@ -43,7 +41,7 @@ struct WarpDistributionPattern : OpRewritePattern<WarpExecuteOnLane0Op> {
4341
WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
4442
RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
4543
ValueRange newYieldedValues, TypeRange newReturnTypes,
46-
llvm::SmallVector<size_t> &indices) const;
44+
SmallVector<size_t> &indices) const;
4745

4846
/// Delinearize the given `laneId` into multiple dimensions, where each
4947
/// dimension's size is determined by `originalShape` and `distributedShape`
@@ -57,7 +55,6 @@ struct WarpDistributionPattern : OpRewritePattern<WarpExecuteOnLane0Op> {
5755
SmallVectorImpl<Value> &delinearizedIds) const;
5856
};
5957

60-
} // namespace gpu
61-
} // namespace mlir
58+
} // namespace mlir::gpu
6259

6360
#endif // MLIR_DIALECT_GPU_TRANSFORMS_DISTRIBUTIONUTILS_H_

mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -51,22 +51,22 @@ WarpExecuteOnLane0Op
5151
WarpDistributionPattern::moveRegionToNewWarpOpAndAppendReturns(
5252
RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
5353
ValueRange newYieldedValues, TypeRange newReturnTypes,
54-
llvm::SmallVector<size_t> &indices) const {
54+
SmallVector<size_t> &indices) const {
5555
SmallVector<Type> types(warpOp.getResultTypes().begin(),
5656
warpOp.getResultTypes().end());
5757
auto yield = cast<gpu::YieldOp>(
5858
warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
5959
llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(),
6060
yield.getOperands().end());
61-
for (auto newRet : llvm::zip(newYieldedValues, newReturnTypes)) {
62-
if (yieldValues.insert(std::get<0>(newRet))) {
63-
types.push_back(std::get<1>(newRet));
61+
for (auto [value, type] : llvm::zip_equal(newYieldedValues, newReturnTypes)) {
62+
if (yieldValues.insert(value)) {
63+
types.push_back(type);
6464
indices.push_back(yieldValues.size() - 1);
6565
} else {
6666
// If the value already exit the region don't create a new output.
6767
for (auto [idx, yieldOperand] :
6868
llvm::enumerate(yieldValues.getArrayRef())) {
69-
if (yieldOperand == std::get<0>(newRet)) {
69+
if (yieldOperand == value) {
7070
indices.push_back(idx);
7171
break;
7272
}
@@ -83,7 +83,7 @@ WarpDistributionPattern::moveRegionToNewWarpOpAndAppendReturns(
8383

8484
OpOperand *WarpDistributionPattern::getWarpResult(
8585
WarpExecuteOnLane0Op warpOp,
86-
const std::function<bool(Operation *)> &fn) const {
86+
const llvm::function_ref<bool(Operation *)> fn) const {
8787
auto yield = cast<gpu::YieldOp>(
8888
warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
8989
for (OpOperand &yieldOperand : yield->getOpOperands()) {
@@ -94,7 +94,7 @@ OpOperand *WarpDistributionPattern::getWarpResult(
9494
return &yieldOperand;
9595
}
9696
}
97-
return {};
97+
return nullptr;
9898
}
9999

100100
bool WarpDistributionPattern::delinearizeLaneId(

0 commit comments

Comments
 (0)