@@ -1376,6 +1376,136 @@ void MakeStridedLayoutPass::runOnOperation() {
   }
 }
 
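+// Rewrites returns of private, single-block functions: when a returned memref
+// comes from a memref.cast, return the cast source's layout and shape instead,
+// updating the function type and all call sites to match.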
+struct ChangeLayoutReturn : public mlir::OpRewritePattern<mlir::ReturnOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(mlir::ReturnOp op,
+                  mlir::PatternRewriter &rewriter) const override {
+    if (op.operands().empty())
+      return mlir::failure();
+
+    auto func = op->getParentOfType<mlir::FuncOp>();
+    if (!func || !func.isPrivate() || !llvm::hasSingleElement(func.getBody()))
+      return mlir::failure();
+
+    auto mod = func->getParentOfType<mlir::ModuleOp>();
+    assert(mod);
+
+    auto funcUses = mlir::SymbolTable::getSymbolUses(func, mod);
+    if (!funcUses)
+      return mlir::failure();
+
+    for (auto use : *funcUses)
+      if (!mlir::isa<mlir::CallOp>(use.getUser()))
+        return mlir::failure();
+
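+    // For each returned memref produced directly by a memref.cast, rebuild its
+    // type from the cast source: keep the source layout and take whichever of
+    // the source/destination extents is statically known per dimension.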
+    auto loc = op->getLoc();
+    auto args = op.operands();
+    auto count = static_cast<unsigned>(args.size());
+    llvm::SmallVector<mlir::Value> newArgs(args.begin(), args.end());
+    llvm::SmallVector<int64_t> shape;
+
+    bool changed = false;
+    for (auto i : llvm::seq(0u, count)) {
+      auto arg = args[i];
+      auto retType = arg.getType().dyn_cast<mlir::MemRefType>();
+      if (!retType)
+        continue;
+
+      auto cast = arg.getDefiningOp<mlir::memref::CastOp>();
+      if (!cast)
+        continue;
+
+      auto src = cast.source();
+      auto srcType = src.getType().cast<mlir::MemRefType>();
+      assert(srcType.getElementType() == retType.getElementType());
+
+      auto srcLayout = srcType.getLayout();
+      auto srcShape = srcType.getShape();
+      auto dstShape = retType.getShape();
+      assert(srcShape.size() == dstShape.size());
+      auto rank = static_cast<unsigned>(srcShape.size());
+      shape.resize(rank);
+      for (auto j : llvm::seq(0u, rank)) {
+        if (!mlir::ShapedType::isDynamic(dstShape[j])) {
+          shape[j] = dstShape[j];
+        } else if (!mlir::ShapedType::isDynamic(srcShape[j])) {
+          shape[j] = srcShape[j];
+        } else {
+          shape[j] = mlir::ShapedType::kDynamicSize;
+        }
+      }
+
+      auto newType = mlir::MemRefType::get(shape, srcType.getElementType(),
+                                           srcLayout, srcType.getMemorySpace());
+      if (newType == retType)
+        continue;
+
+      auto newArg = rewriter.create<mlir::memref::CastOp>(loc, newType, src);
+      newArgs[i] = newArg;
+      changed = true;
+    }
+
+    if (!changed)
+      return mlir::failure();
+
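+    // The return itself is replaced first; the function type and every call
+    // site are then updated to the new result types.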
+    rewriter.replaceOpWithNewOp<mlir::ReturnOp>(op, newArgs);
+
+    auto newFuncType = [&]() {
+      auto origType = func.getType();
+      mlir::ValueRange r(newArgs);
+      return mlir::FunctionType::get(getContext(), origType.getInputs(),
+                                     r.getTypes());
+    }();
+
+    rewriter.updateRootInPlace(
+        func, [&]() { func.typeAttr(mlir::TypeAttr::get(newFuncType)); });
+
+    llvm::SmallVector<mlir::CallOp> calls;
+    for (auto use : *funcUses) {
+      auto call = mlir::cast<mlir::CallOp>(use.getUser());
+      calls.emplace_back(call);
+    }
+
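+    // Recreate each call with the new result types, casting results back to
+    // the old types where they differ so existing users stay valid.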
+    for (auto call : calls) {
+      rewriter.setInsertionPoint(call);
+      auto callLoc = call->getLoc();
+      auto oldResults = call.getResults();
+      auto newResults =
+          rewriter.create<mlir::CallOp>(callLoc, func, call.operands())
+              .getResults();
+      newArgs.assign(newResults.begin(), newResults.end());
+      for (auto i : llvm::seq(0u, count)) {
+        auto oldType = oldResults[i].getType();
+        auto newType = newArgs[i].getType();
+        if (oldType != newType)
+          newArgs[i] = rewriter.create<mlir::memref::CastOp>(callLoc, oldType,
+                                                             newArgs[i]);
+      }
+      rewriter.replaceOp(call, newArgs);
+    }
+
+    return mlir::success();
+  }
+};
+
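+// Runs ChangeLayoutReturn together with the usual canonicalization patterns
+// over the whole module.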
+struct OptimizeStridedLayoutPass
+    : public mlir::PassWrapper<OptimizeStridedLayoutPass,
+                               mlir::OperationPass<mlir::ModuleOp>> {
+  void runOnOperation() override {
+    auto *context = &getContext();
+    mlir::RewritePatternSet patterns(context);
+
+    plier::populateCanonicalizationPatterns(*context, patterns);
+
+    patterns.insert<ChangeLayoutReturn>(context);
+
+    (void)mlir::applyPatternsAndFoldGreedily(getOperation(),
+                                             std::move(patterns));
+  }
+};
+
 struct FinalizeStridedLayoutPass
     : public mlir::PassWrapper<FinalizeStridedLayoutPass,
                                mlir::OperationPass<>> {
@@ -2716,7 +2846,7 @@ static void populatePlierToLinalgOptPipeline(mlir::OpPassManager &pm) {
 
   pm.addNestedPass<mlir::FuncOp>(std::make_unique<CloneArgsPass>());
   pm.addPass(std::make_unique<MakeStridedLayoutPass>());
-  pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
+  pm.addPass(std::make_unique<OptimizeStridedLayoutPass>());
   pm.addNestedPass<mlir::FuncOp>(std::make_unique<FinalizeStridedLayoutPass>());
   pm.addNestedPass<mlir::FuncOp>(
       mlir::bufferization::createBufferDeallocationPass());