Skip to content

Commit a4c1351

Browse files
authored
Propagate layout change throug return op (#154)
1 parent a10c083 commit a4c1351

File tree

3 files changed

+182
-5
lines changed

3 files changed

+182
-5
lines changed

mlir/lib/dialect/plier_util/dialect.cpp

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -545,6 +545,34 @@ struct ChangeLayoutCast : public mlir::OpRewritePattern<mlir::memref::CastOp> {
545545
}
546546
};
547547

548+
struct ChangeLayoutFromCast
549+
: public mlir::OpRewritePattern<plier::ChangeLayoutOp> {
550+
using OpRewritePattern::OpRewritePattern;
551+
552+
mlir::LogicalResult
553+
matchAndRewrite(plier::ChangeLayoutOp op,
554+
mlir::PatternRewriter &rewriter) const override {
555+
auto cast = op.source().getDefiningOp<mlir::memref::CastOp>();
556+
if (!cast)
557+
return mlir::failure();
558+
559+
auto src = cast.source();
560+
auto srcType = src.getType();
561+
auto dstType = op.getType();
562+
if (srcType == dstType) {
563+
rewriter.replaceOp(op, src);
564+
return mlir::success();
565+
}
566+
567+
if (mlir::memref::CastOp::areCastCompatible(srcType, dstType)) {
568+
rewriter.replaceOpWithNewOp<mlir::memref::CastOp>(op, dstType, src);
569+
return mlir::success();
570+
}
571+
572+
return mlir::failure();
573+
}
574+
};
575+
548576
struct ChangeLayoutSignCast : public mlir::OpRewritePattern<plier::SignCastOp> {
549577
using OpRewritePattern::OpRewritePattern;
550578

@@ -953,10 +981,11 @@ void ChangeLayoutOp::getCanonicalizationPatterns(
953981
results.insert<
954982
ChangeLayoutIdentity, ChangeLayoutDim, ChangeLayoutExtractMetadata,
955983
ChangeLayoutClone, PropagateCloneType, ChangeLayoutCast,
956-
ChangeLayoutSignCast, ChangeLayoutLoad, ChangeLayoutStore,
957-
ChangeLayoutSubview, ChangeLayoutLinalgGeneric, ChangeLayoutLinalgFill,
958-
ChangeLayoutIf, ChangeLayout1DReshape, ChangeLayoutSliceGetItem,
959-
ChangeLayoutCopy, ChangeLayoutExpandShape>(context);
984+
ChangeLayoutFromCast, ChangeLayoutSignCast, ChangeLayoutLoad,
985+
ChangeLayoutStore, ChangeLayoutSubview, ChangeLayoutLinalgGeneric,
986+
ChangeLayoutLinalgFill, ChangeLayoutIf, ChangeLayout1DReshape,
987+
ChangeLayoutSliceGetItem, ChangeLayoutCopy, ChangeLayoutExpandShape>(
988+
context);
960989
}
961990

962991
static mlir::Value propagateCasts(mlir::Value val, mlir::Type thisType);

numba_dpcomp/numba_dpcomp/mlir/tests/test_numpy.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -833,6 +833,24 @@ def py_func(a):
833833
ir = get_print_buffer()
834834
assert ir.count('affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>') == 1, ir
835835

836+
def test_contigious_layout_return():
837+
def py_func1():
838+
return np.ones((2,3), np.float32).T
839+
840+
jit_func1 = njit(py_func1)
841+
842+
def py_func2(a):
843+
return a
844+
845+
jit_func2 = njit(py_func2)
846+
847+
def py_func3():
848+
a = jit_func1()
849+
return jit_func2(a)
850+
851+
jit_func3 = njit(py_func3)
852+
853+
assert_equal(py_func3(), jit_func3())
836854

837855
@parametrize_function_variants("a", [
838856
# 'np.array(1)', TODO zero rank arrays

numba_dpcomp/numba_dpcomp/mlir_compiler/lib/pipelines/plier_to_linalg.cpp

Lines changed: 131 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1376,6 +1376,136 @@ void MakeStridedLayoutPass::runOnOperation() {
13761376
}
13771377
}
13781378

1379+
struct ChangeLayoutReturn : public mlir::OpRewritePattern<mlir::ReturnOp> {
1380+
using OpRewritePattern::OpRewritePattern;
1381+
1382+
mlir::LogicalResult
1383+
matchAndRewrite(mlir::ReturnOp op,
1384+
mlir::PatternRewriter &rewriter) const override {
1385+
if (op.operands().empty())
1386+
return mlir::failure();
1387+
1388+
auto func = op->getParentOfType<mlir::FuncOp>();
1389+
if (!func || !func.isPrivate() || !llvm::hasSingleElement(func.getBody()))
1390+
return mlir::failure();
1391+
1392+
auto mod = func->getParentOfType<mlir::ModuleOp>();
1393+
assert(mod);
1394+
1395+
auto funcUses = mlir::SymbolTable::getSymbolUses(func, mod);
1396+
if (!funcUses)
1397+
return mlir::failure();
1398+
1399+
for (auto use : *funcUses)
1400+
if (!mlir::isa<mlir::CallOp>(use.getUser()))
1401+
return mlir::failure();
1402+
1403+
auto loc = op->getLoc();
1404+
auto args = op.operands();
1405+
auto count = static_cast<unsigned>(args.size());
1406+
llvm::SmallVector<mlir::Value> newArgs(args.begin(), args.end());
1407+
llvm::SmallVector<int64_t> shape;
1408+
1409+
bool changed = false;
1410+
for (auto i : llvm::seq(0u, count)) {
1411+
auto arg = args[i];
1412+
auto retType = arg.getType().dyn_cast<mlir::MemRefType>();
1413+
if (!retType)
1414+
continue;
1415+
1416+
auto cast = arg.getDefiningOp<mlir::memref::CastOp>();
1417+
if (!cast)
1418+
continue;
1419+
1420+
auto src = cast.source();
1421+
auto srcType = src.getType().cast<mlir::MemRefType>();
1422+
assert(srcType.getElementType() == retType.getElementType());
1423+
1424+
auto srcLayout = srcType.getLayout();
1425+
auto srcShape = srcType.getShape();
1426+
auto dstShape = retType.getShape();
1427+
assert(srcShape.size() == dstShape.size());
1428+
auto rank = static_cast<unsigned>(srcShape.size());
1429+
shape.resize(rank);
1430+
for (auto j : llvm::seq(0u, rank)) {
1431+
if (!mlir::ShapedType::isDynamic(dstShape[j])) {
1432+
shape[j] = dstShape[j];
1433+
} else if (!mlir::ShapedType::isDynamic(srcShape[j])) {
1434+
shape[j] = srcShape[j];
1435+
} else {
1436+
shape[j] = mlir::ShapedType::kDynamicSize;
1437+
}
1438+
}
1439+
1440+
auto newType = mlir::MemRefType::get(shape, srcType.getElementType(),
1441+
srcLayout, srcType.getMemorySpace());
1442+
if (newType == retType)
1443+
continue;
1444+
1445+
auto newArg = rewriter.create<mlir::memref::CastOp>(loc, newType, src);
1446+
newArgs[i] = newArg;
1447+
changed = true;
1448+
}
1449+
1450+
if (!changed)
1451+
return mlir::failure();
1452+
1453+
rewriter.replaceOpWithNewOp<mlir::ReturnOp>(op, newArgs);
1454+
1455+
auto newFuncType = [&]() {
1456+
auto origType = func.getType();
1457+
mlir::ValueRange r(newArgs);
1458+
return mlir::FunctionType::get(getContext(), origType.getInputs(),
1459+
r.getTypes());
1460+
}();
1461+
1462+
rewriter.updateRootInPlace(
1463+
func, [&]() { func.typeAttr(mlir::TypeAttr::get(newFuncType)); });
1464+
1465+
llvm::SmallVector<mlir::CallOp> calls;
1466+
for (auto use : *funcUses) {
1467+
auto call = mlir::cast<mlir::CallOp>(use.getUser());
1468+
calls.emplace_back(call);
1469+
}
1470+
1471+
for (auto call : calls) {
1472+
rewriter.setInsertionPoint(call);
1473+
auto callLoc = call->getLoc();
1474+
auto oldResults = call.getResults();
1475+
auto newResults =
1476+
rewriter.create<mlir::CallOp>(callLoc, func, call.operands())
1477+
.getResults();
1478+
newArgs.assign(newResults.begin(), newResults.end());
1479+
for (auto i : llvm::seq(0u, count)) {
1480+
auto oldType = oldResults[i].getType();
1481+
auto newType = newArgs[i].getType();
1482+
if (oldType != newType)
1483+
newArgs[i] = rewriter.create<mlir::memref::CastOp>(callLoc, oldType,
1484+
newArgs[i]);
1485+
}
1486+
rewriter.replaceOp(call, newArgs);
1487+
}
1488+
1489+
return mlir::success();
1490+
}
1491+
};
1492+
1493+
struct OptimizeStridedLayoutPass
1494+
: public mlir::PassWrapper<OptimizeStridedLayoutPass,
1495+
mlir::OperationPass<mlir::ModuleOp>> {
1496+
void runOnOperation() override {
1497+
auto *context = &getContext();
1498+
mlir::RewritePatternSet patterns(context);
1499+
1500+
plier::populateCanonicalizationPatterns(*context, patterns);
1501+
1502+
patterns.insert<ChangeLayoutReturn>(context);
1503+
1504+
(void)mlir::applyPatternsAndFoldGreedily(getOperation(),
1505+
std::move(patterns));
1506+
}
1507+
};
1508+
13791509
struct FinalizeStridedLayoutPass
13801510
: public mlir::PassWrapper<FinalizeStridedLayoutPass,
13811511
mlir::OperationPass<>> {
@@ -2716,7 +2846,7 @@ static void populatePlierToLinalgOptPipeline(mlir::OpPassManager &pm) {
27162846

27172847
pm.addNestedPass<mlir::FuncOp>(std::make_unique<CloneArgsPass>());
27182848
pm.addPass(std::make_unique<MakeStridedLayoutPass>());
2719-
pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
2849+
pm.addPass(std::make_unique<OptimizeStridedLayoutPass>());
27202850
pm.addNestedPass<mlir::FuncOp>(std::make_unique<FinalizeStridedLayoutPass>());
27212851
pm.addNestedPass<mlir::FuncOp>(
27222852
mlir::bufferization::createBufferDeallocationPass());

0 commit comments

Comments
 (0)