@@ -1167,6 +1167,9 @@ class ConvertBarrierOps : public mlir::OpRewritePattern<mlir::func::CallOp> {
1167
1167
mlir::LogicalResult
1168
1168
matchAndRewrite (mlir::func::CallOp op,
1169
1169
mlir::PatternRewriter &rewriter) const override {
1170
+ if (!op->getParentOfType <mlir::gpu::LaunchOp>())
1171
+ return mlir::failure ();
1172
+
1170
1173
auto operands = op.operands ();
1171
1174
if (operands.size () != 1 )
1172
1175
return mlir::failure ();
@@ -1197,6 +1200,10 @@ class ConvertBarrierOps : public mlir::OpRewritePattern<mlir::func::CallOp> {
1197
1200
}
1198
1201
};
1199
1202
1203
+ struct LowerGpuBuiltins2Pass
1204
+ : public plier::RewriteWrapperPass<LowerGpuBuiltins2Pass, void , void ,
1205
+ ConvertBarrierOps> {};
1206
+
1200
1207
class ConvertArrayAllocOps : public mlir ::OpRewritePattern<mlir::func::CallOp> {
1201
1208
public:
1202
1209
using OpRewritePattern::OpRewritePattern;
@@ -1279,9 +1286,8 @@ class ConvertArrayAllocOps : public mlir::OpRewritePattern<mlir::func::CallOp> {
1279
1286
}
1280
1287
};
1281
1288
1282
- struct LowerGpuBuiltins2Pass
1283
- : public plier::RewriteWrapperPass<LowerGpuBuiltins2Pass, void , void ,
1284
- ConvertBarrierOps,
1289
+ struct LowerGpuBuiltins3Pass
1290
+ : public plier::RewriteWrapperPass<LowerGpuBuiltins3Pass, void , void ,
1285
1291
ConvertArrayAllocOps> {};
1286
1292
1287
1293
class GpuLaunchSinkOpsPass
@@ -1412,6 +1418,7 @@ static void populateLowerToGPUPipelineLow(mlir::OpPassManager &pm) {
1412
1418
1413
1419
commonOptPasses (funcPM);
1414
1420
funcPM.addPass (std::make_unique<KernelMemrefOpsMovementPass>());
1421
+ funcPM.addPass (std::make_unique<LowerGpuBuiltins2Pass>());
1415
1422
funcPM.addPass (std::make_unique<SinkGpuDimsPass>());
1416
1423
funcPM.addPass (std::make_unique<GpuLaunchSinkOpsPass>());
1417
1424
pm.addPass (mlir::createGpuKernelOutliningPass ());
@@ -1425,7 +1432,7 @@ static void populateLowerToGPUPipelineLow(mlir::OpPassManager &pm) {
1425
1432
pm.nest <mlir::gpu::GPUModuleOp>().nest <mlir::gpu::GPUFuncOp>();
1426
1433
gpuFuncPM.addPass (mlir::arith::createArithmeticExpandOpsPass ());
1427
1434
gpuFuncPM.addPass (std::make_unique<FlattenScfPass>());
1428
- gpuFuncPM.addPass (std::make_unique<LowerGpuBuiltins2Pass >());
1435
+ gpuFuncPM.addPass (std::make_unique<LowerGpuBuiltins3Pass >());
1429
1436
commonOptPasses (gpuFuncPM);
1430
1437
gpuFuncPM.addPass (std::make_unique<AssumeGpuIdRangePass>());
1431
1438
0 commit comments