Skip to content

Commit 1270a25

Browse files
authored
[python][gpu] GPU code refactoring (#207)
1 parent 1637cca commit 1270a25

File tree

2 files changed

+14
-5
lines changed

2 files changed

+14
-5
lines changed

mlir/lib/Conversion/gpu_runtime_to_llvm.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -498,8 +498,10 @@ class ConvertGpuKernelLaunchPattern
498498
mlir::MemRefDescriptor desc(kernelParams[i]);
499499
if (memrefType.getMemorySpace() == localMemStorageClass) {
500500
auto rank = static_cast<unsigned>(memrefType.getRank());
501+
auto typeSize = std::max(memrefType.getElementTypeBitWidth(), 8u) / 8;
501502
mlir::Value size = rewriter.create<mlir::LLVM::ConstantOp>(
502-
loc, llvmIndexType, rewriter.getIntegerAttr(llvmIndexType, 0));
503+
loc, llvmIndexType,
504+
rewriter.getIntegerAttr(llvmIndexType, typeSize));
503505
for (auto i : llvm::seq(0u, rank)) {
504506
auto dim = desc.size(rewriter, loc, i);
505507
size = rewriter.create<mlir::LLVM::MulOp>(loc, llvmIndexType, size,

numba_dpcomp/numba_dpcomp/mlir_compiler/lib/pipelines/lower_to_gpu.cpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1167,6 +1167,9 @@ class ConvertBarrierOps : public mlir::OpRewritePattern<mlir::func::CallOp> {
11671167
mlir::LogicalResult
11681168
matchAndRewrite(mlir::func::CallOp op,
11691169
mlir::PatternRewriter &rewriter) const override {
1170+
if (!op->getParentOfType<mlir::gpu::LaunchOp>())
1171+
return mlir::failure();
1172+
11701173
auto operands = op.operands();
11711174
if (operands.size() != 1)
11721175
return mlir::failure();
@@ -1197,6 +1200,10 @@ class ConvertBarrierOps : public mlir::OpRewritePattern<mlir::func::CallOp> {
11971200
}
11981201
};
11991202

1203+
struct LowerGpuBuiltins2Pass
1204+
: public plier::RewriteWrapperPass<LowerGpuBuiltins2Pass, void, void,
1205+
ConvertBarrierOps> {};
1206+
12001207
class ConvertArrayAllocOps : public mlir::OpRewritePattern<mlir::func::CallOp> {
12011208
public:
12021209
using OpRewritePattern::OpRewritePattern;
@@ -1279,9 +1286,8 @@ class ConvertArrayAllocOps : public mlir::OpRewritePattern<mlir::func::CallOp> {
12791286
}
12801287
};
12811288

1282-
struct LowerGpuBuiltins2Pass
1283-
: public plier::RewriteWrapperPass<LowerGpuBuiltins2Pass, void, void,
1284-
ConvertBarrierOps,
1289+
struct LowerGpuBuiltins3Pass
1290+
: public plier::RewriteWrapperPass<LowerGpuBuiltins3Pass, void, void,
12851291
ConvertArrayAllocOps> {};
12861292

12871293
class GpuLaunchSinkOpsPass
@@ -1412,6 +1418,7 @@ static void populateLowerToGPUPipelineLow(mlir::OpPassManager &pm) {
14121418

14131419
commonOptPasses(funcPM);
14141420
funcPM.addPass(std::make_unique<KernelMemrefOpsMovementPass>());
1421+
funcPM.addPass(std::make_unique<LowerGpuBuiltins2Pass>());
14151422
funcPM.addPass(std::make_unique<SinkGpuDimsPass>());
14161423
funcPM.addPass(std::make_unique<GpuLaunchSinkOpsPass>());
14171424
pm.addPass(mlir::createGpuKernelOutliningPass());
@@ -1425,7 +1432,7 @@ static void populateLowerToGPUPipelineLow(mlir::OpPassManager &pm) {
14251432
pm.nest<mlir::gpu::GPUModuleOp>().nest<mlir::gpu::GPUFuncOp>();
14261433
gpuFuncPM.addPass(mlir::arith::createArithmeticExpandOpsPass());
14271434
gpuFuncPM.addPass(std::make_unique<FlattenScfPass>());
1428-
gpuFuncPM.addPass(std::make_unique<LowerGpuBuiltins2Pass>());
1435+
gpuFuncPM.addPass(std::make_unique<LowerGpuBuiltins3Pass>());
14291436
commonOptPasses(gpuFuncPM);
14301437
gpuFuncPM.addPass(std::make_unique<AssumeGpuIdRangePass>());
14311438

0 commit comments

Comments
 (0)