@@ -1341,6 +1341,57 @@ struct GPULowerDefaultLocalSize
1341
1341
}
1342
1342
};
1343
1343
1344
+ struct FlattenScfIf : public mlir ::OpRewritePattern<mlir::scf::IfOp> {
1345
+ using OpRewritePattern::OpRewritePattern;
1346
+
1347
+ mlir::LogicalResult
1348
+ matchAndRewrite (mlir::scf::IfOp op,
1349
+ mlir::PatternRewriter &rewriter) const override {
1350
+ if (op->getNumResults () == 0 )
1351
+ return mlir::failure ();
1352
+
1353
+ auto arithDialect =
1354
+ getContext ()->getOrLoadDialect <mlir::arith::ArithmeticDialect>();
1355
+ auto canFlatten = [&](mlir::Operation *op) {
1356
+ return op->getDialect () == arithDialect;
1357
+ };
1358
+
1359
+ auto &trueBody = op.getThenRegion ().front ();
1360
+ auto &falseBody = op.getElseRegion ().front ();
1361
+ for (auto *block : {&trueBody, &falseBody})
1362
+ for (auto &op : block->without_terminator ())
1363
+ if (!canFlatten (&op))
1364
+ return mlir::failure ();
1365
+
1366
+ mlir::BlockAndValueMapping mapper;
1367
+ for (auto *block : {&trueBody, &falseBody})
1368
+ for (auto &op : block->without_terminator ())
1369
+ rewriter.clone (op, mapper);
1370
+
1371
+ auto trueYield = mlir::cast<mlir::scf::YieldOp>(trueBody.getTerminator ());
1372
+ auto falseYield = mlir::cast<mlir::scf::YieldOp>(falseBody.getTerminator ());
1373
+
1374
+ llvm::SmallVector<mlir::Value> results;
1375
+ results.reserve (op->getNumResults ());
1376
+
1377
+ auto loc = op->getLoc ();
1378
+ auto cond = op.getCondition ();
1379
+ for (auto it : llvm::zip (trueYield.getResults (), falseYield.getResults ())) {
1380
+ auto trueVal = mapper.lookupOrDefault (std::get<0 >(it));
1381
+ auto falseVal = mapper.lookupOrDefault (std::get<1 >(it));
1382
+ auto res =
1383
+ rewriter.create <mlir::arith::SelectOp>(loc, cond, trueVal, falseVal);
1384
+ results.emplace_back (res);
1385
+ }
1386
+
1387
+ rewriter.replaceOp (op, results);
1388
+ return mlir::success ();
1389
+ }
1390
+ };
1391
+
1392
+ struct FlattenScfPass : public plier ::RewriteWrapperPass<FlattenScfPass, void ,
1393
+ void , FlattenScfIf> {};
1394
+
1344
1395
template <typename Op, typename F>
1345
1396
static mlir::LogicalResult createGpuKernelLoad (mlir::PatternRewriter &builder,
1346
1397
Op &&op, F &&func) {
@@ -2972,11 +3023,15 @@ static void populateLowerToGPUPipelineLow(mlir::OpPassManager &pm) {
2972
3023
pm.addPass (mlir::createSymbolDCEPass ());
2973
3024
2974
3025
pm.addNestedPass <mlir::FuncOp>(std::make_unique<GPULowerDefaultLocalSize>());
2975
- pm.nest <mlir::gpu::GPUModuleOp>().addNestedPass <mlir::gpu::GPUFuncOp>(
2976
- mlir::arith::createArithmeticExpandOpsPass ());
2977
3026
pm.addNestedPass <mlir::FuncOp>(mlir::createCanonicalizerPass ());
2978
3027
pm.addPass (mlir::createSymbolDCEPass ());
2979
- pm.addNestedPass <mlir::gpu::GPUModuleOp>(mlir::createCanonicalizerPass ());
3028
+
3029
+ auto &gpuFuncPM =
3030
+ pm.nest <mlir::gpu::GPUModuleOp>().nest <mlir::gpu::GPUFuncOp>();
3031
+ gpuFuncPM.addPass (mlir::arith::createArithmeticExpandOpsPass ());
3032
+ gpuFuncPM.addPass (std::make_unique<FlattenScfPass>());
3033
+ commonOptPasses (gpuFuncPM);
3034
+
2980
3035
pm.addNestedPass <mlir::gpu::GPUModuleOp>(std::make_unique<AbiAttrsPass>());
2981
3036
pm.addPass (std::make_unique<SetSPIRVCapabilitiesPass>());
2982
3037
pm.addPass (std::make_unique<GPUToSpirvPass>());
0 commit comments