Skip to content

Commit f2d1eb0

Browse files
authored
GPU lowering refactoring (#157)
1 parent 4ada912 commit f2d1eb0

File tree

6 files changed

+44
-14
lines changed

6 files changed

+44
-14
lines changed

mlir/include/mlir-extensions/transforms/common_opts.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,12 @@
1414

1515
#pragma once
1616

17+
#include <memory>
18+
1719
namespace mlir {
1820
class RewritePatternSet;
1921
class MLIRContext;
22+
class Pass;
2023
} // namespace mlir
2124

2225
namespace plier {
@@ -25,4 +28,6 @@ void populateCanonicalizationPatterns(mlir::MLIRContext &context,
2528

2629
void populateCommonOptsPatterns(mlir::MLIRContext &context,
2730
mlir::RewritePatternSet &patterns);
31+
32+
std::unique_ptr<mlir::Pass> createCommonOptsPass();
2833
} // namespace plier

mlir/include/mlir-extensions/transforms/index_type_propagation.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,6 @@ class MLIRContext;
2020
} // namespace mlir
2121

2222
namespace plier {
23-
void populate_index_propagate_patterns(mlir::MLIRContext &context,
24-
mlir::RewritePatternSet &patterns);
23+
void populateIndexPropagatePatterns(mlir::MLIRContext &context,
24+
mlir::RewritePatternSet &patterns);
2525
}

mlir/lib/transforms/common_opts.cpp

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
#include "mlir-extensions/transforms/index_type_propagation.hpp"
2020
#include "mlir-extensions/transforms/loop_rewrites.hpp"
2121
#include "mlir-extensions/transforms/memory_rewrites.hpp"
22+
#include "mlir/Pass/Pass.h"
23+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2224

2325
#include <mlir/Dialect/Math/IR/Math.h>
2426
#include <mlir/Dialect/MemRef/IR/MemRef.h>
@@ -129,6 +131,20 @@ struct PowSimplify : public mlir::OpRewritePattern<mlir::math::PowFOp> {
129131
return mlir::failure();
130132
}
131133
};
134+
135+
struct CommonOptsPass
136+
: public mlir::PassWrapper<CommonOptsPass, mlir::OperationPass<void>> {
137+
138+
void runOnOperation() override {
139+
auto *ctx = &getContext();
140+
mlir::RewritePatternSet patterns(ctx);
141+
142+
plier::populateCommonOptsPatterns(*ctx, patterns);
143+
144+
(void)mlir::applyPatternsAndFoldGreedily(getOperation(),
145+
std::move(patterns));
146+
}
147+
};
132148
} // namespace
133149

134150
void plier::populateCanonicalizationPatterns(
@@ -155,5 +171,9 @@ void plier::populateCommonOptsPatterns(mlir::MLIRContext &context,
155171
// clang-format on
156172
>(&context);
157173

158-
plier::populate_index_propagate_patterns(context, patterns);
174+
plier::populateIndexPropagatePatterns(context, patterns);
175+
}
176+
177+
std::unique_ptr<mlir::Pass> plier::createCommonOptsPass() {
178+
return std::make_unique<CommonOptsPass>();
159179
}

mlir/lib/transforms/index_type_propagation.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,8 +147,8 @@ struct CmpIndexCastSimplify
147147
};
148148
} // namespace
149149

150-
void plier::populate_index_propagate_patterns(
151-
mlir::MLIRContext &context, mlir::RewritePatternSet &patterns) {
150+
void plier::populateIndexPropagatePatterns(mlir::MLIRContext &context,
151+
mlir::RewritePatternSet &patterns) {
152152
patterns
153153
.insert<CmpIndexCastSimplify, ArithIndexCastSimplify<mlir::arith::SubIOp>,
154154
ArithIndexCastSimplify<mlir::arith::AddIOp>,

numba_dpcomp/numba_dpcomp/mlir_compiler/lib/pipelines/lower_to_gpu.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
#include "mlir-extensions/dialect/plier_util/dialect.hpp"
6868
#include "mlir-extensions/transforms/call_lowering.hpp"
6969
#include "mlir-extensions/transforms/cast_utils.hpp"
70+
#include "mlir-extensions/transforms/common_opts.hpp"
7071
#include "mlir-extensions/transforms/const_utils.hpp"
7172
#include "mlir-extensions/transforms/func_utils.hpp"
7273
#include "mlir-extensions/transforms/pipeline_utils.hpp"
@@ -2986,9 +2987,9 @@ class GpuLaunchSinkOpsPass
29862987
};
29872988

29882989
static void commonOptPasses(mlir::OpPassManager &pm) {
2989-
pm.addPass(mlir::createCanonicalizerPass());
2990+
pm.addPass(plier::createCommonOptsPass());
29902991
pm.addPass(mlir::createCSEPass());
2991-
pm.addPass(mlir::createCanonicalizerPass());
2992+
pm.addPass(plier::createCommonOptsPass());
29922993
}
29932994

29942995
static void populateLowerToGPUPipelineHigh(mlir::OpPassManager &pm) {
@@ -3014,8 +3015,6 @@ static void populateLowerToGPUPipelineLow(mlir::OpPassManager &pm) {
30143015
funcPM.addPass(std::make_unique<UnstrideMemrefsPass>());
30153016
funcPM.addPass(mlir::createLowerAffinePass());
30163017

3017-
// TODO: mlir::gpu::GPUModuleOp pass
3018-
pm.addNestedPass<mlir::FuncOp>(mlir::arith::createArithmeticExpandOpsPass());
30193018
commonOptPasses(funcPM);
30203019
funcPM.addPass(std::make_unique<KernelMemrefOpsMovementPass>());
30213020
funcPM.addPass(std::make_unique<GpuLaunchSinkOpsPass>());

numba_dpcomp/numba_dpcomp/mlir_compiler/lib/pipelines/plier_to_linalg.cpp

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -642,6 +642,7 @@ computeIndices(mlir::OpBuilder &builder, mlir::Location loc, mlir::Value value,
642642
auto valType = indexVal.getType();
643643

644644
auto len = getDim(dim);
645+
bool ignoreNegativeInd = false;
645646
auto handleNegativeVal = [&](mlir::OpFoldResult val) -> mlir::Value {
646647
mlir::Value idx;
647648
if (auto v = val.dyn_cast<mlir::Value>()) {
@@ -651,11 +652,16 @@ computeIndices(mlir::OpBuilder &builder, mlir::Location loc, mlir::Value value,
651652
auto attrVal = attr.cast<mlir::IntegerAttr>().getValue().getSExtValue();
652653
idx = builder.create<mlir::arith::ConstantIndexOp>(loc, attrVal);
653654
}
654-
auto isNeg = builder.createOrFold<mlir::arith::CmpIOp>(
655-
loc, mlir::arith::CmpIPredicate::slt, idx, zero);
656-
auto negIndex = builder.createOrFold<mlir::arith::AddIOp>(loc, len, idx);
657-
return builder.createOrFold<mlir::arith::SelectOp>(loc, isNeg, negIndex,
658-
idx);
655+
if (ignoreNegativeInd) {
656+
return idx;
657+
} else {
658+
auto isNeg = builder.createOrFold<mlir::arith::CmpIOp>(
659+
loc, mlir::arith::CmpIPredicate::slt, idx, zero);
660+
auto negIndex =
661+
builder.createOrFold<mlir::arith::AddIOp>(loc, len, idx);
662+
return builder.createOrFold<mlir::arith::SelectOp>(loc, isNeg, negIndex,
663+
idx);
664+
}
659665
};
660666

661667
if (auto sliceType = valType.dyn_cast<plier::SliceType>()) {

0 commit comments

Comments
 (0)