Skip to content

Commit 898ee75

Browse files
authored
SignCastOp propagation refactoring (#182)
1 parent f348c1e commit 898ee75

File tree

2 files changed

+62
-150
lines changed

2 files changed

+62
-150
lines changed

mlir/lib/dialect/plier_util/dialect.cpp

Lines changed: 62 additions & 149 deletions
Original file line numberDiff line numberDiff line change
@@ -699,33 +699,6 @@ struct ChangeLayoutFromCast
699699
}
700700
};
701701

702-
struct ChangeLayoutSignCast : public mlir::OpRewritePattern<plier::SignCastOp> {
703-
using OpRewritePattern::OpRewritePattern;
704-
705-
mlir::LogicalResult
706-
matchAndRewrite(plier::SignCastOp op,
707-
mlir::PatternRewriter &rewriter) const override {
708-
auto cl = op.value().getDefiningOp<plier::ChangeLayoutOp>();
709-
if (!cl)
710-
return mlir::failure();
711-
712-
auto src = cl.source();
713-
auto srcType = src.getType().cast<mlir::MemRefType>();
714-
auto oldType = op.getType().cast<mlir::MemRefType>();
715-
auto newType = mlir::MemRefType::get(
716-
srcType.getShape(), oldType.getElementType(), srcType.getLayout());
717-
718-
auto loc = op.getLoc();
719-
auto newOp = rewriter.createOrFold<plier::SignCastOp>(loc, newType, src);
720-
721-
if (oldType != newType)
722-
newOp = rewriter.createOrFold<plier::ChangeLayoutOp>(loc, oldType, newOp);
723-
724-
rewriter.replaceOp(op, newOp);
725-
return mlir::success();
726-
}
727-
};
728-
729702
struct ChangeLayoutLoad : public mlir::OpRewritePattern<mlir::memref::LoadOp> {
730703
using OpRewritePattern::OpRewritePattern;
731704

@@ -1107,11 +1080,10 @@ void ChangeLayoutOp::getCanonicalizationPatterns(
11071080
results.insert<
11081081
ChangeLayoutIdentity, ChangeLayoutDim, ChangeLayoutExtractMetadata,
11091082
ChangeLayoutClone, PropagateCloneType, ChangeLayoutCast,
1110-
ChangeLayoutFromCast, ChangeLayoutSignCast, ChangeLayoutLoad,
1111-
ChangeLayoutStore, ChangeLayoutSubview, ChangeLayoutLinalgGeneric,
1112-
ChangeLayoutLinalgFill, ChangeLayoutIf, ChangeLayout1DReshape,
1113-
ChangeLayoutSliceGetItem, ChangeLayoutCopy, ChangeLayoutExpandShape>(
1114-
context);
1083+
ChangeLayoutFromCast, ChangeLayoutLoad, ChangeLayoutStore,
1084+
ChangeLayoutSubview, ChangeLayoutLinalgGeneric, ChangeLayoutLinalgFill,
1085+
ChangeLayoutIf, ChangeLayout1DReshape, ChangeLayoutSliceGetItem,
1086+
ChangeLayoutCopy, ChangeLayoutExpandShape>(context);
11151087
}
11161088

11171089
static mlir::Value propagateCasts(mlir::Value val, mlir::Type thisType);
@@ -1198,65 +1170,31 @@ struct SignCastUndefPropagate
11981170
}
11991171
};
12001172

1201-
struct SignCastTensorCastPropagate
1202-
: public mlir::OpRewritePattern<plier::SignCastOp> {
1203-
using OpRewritePattern::OpRewritePattern;
1204-
1205-
mlir::LogicalResult
1206-
matchAndRewrite(plier::SignCastOp op,
1207-
mlir::PatternRewriter &rewriter) const override {
1208-
auto tensorCast = op.value().getDefiningOp<mlir::tensor::CastOp>();
1209-
if (!tensorCast)
1210-
return mlir::failure();
1211-
1212-
auto srcType = tensorCast.source().getType().cast<mlir::TensorType>();
1213-
auto dstType = tensorCast.getType().cast<mlir::TensorType>();
1214-
if (srcType.getElementType() != dstType.getElementType() ||
1215-
!srcType.hasRank() || !dstType.hasRank())
1216-
return mlir::failure();
1217-
1218-
auto finalType = op.getType().cast<mlir::TensorType>();
1219-
auto finalElemType = finalType.getElementType();
1220-
1221-
auto newSrcType = srcType.clone(finalElemType);
1222-
auto newDstType = dstType.clone(finalElemType);
1223-
1224-
auto loc = op.getLoc();
1225-
auto casted = rewriter.createOrFold<plier::SignCastOp>(loc, newSrcType,
1226-
tensorCast.source());
1227-
rewriter.replaceOpWithNewOp<mlir::tensor::CastOp>(op, newDstType, casted);
1228-
1229-
return mlir::success();
1230-
}
1231-
};
1232-
1233-
struct SignCastMemrefCastPropagate
1234-
: public mlir::OpRewritePattern<plier::SignCastOp> {
1235-
using OpRewritePattern::OpRewritePattern;
1173+
template <typename CastOp>
1174+
struct SignCastCastPropagate : public mlir::OpRewritePattern<CastOp> {
1175+
using mlir::OpRewritePattern<CastOp>::OpRewritePattern;
12361176

12371177
mlir::LogicalResult
1238-
matchAndRewrite(plier::SignCastOp op,
1239-
mlir::PatternRewriter &rewriter) const override {
1240-
auto memrefCast = op.value().getDefiningOp<mlir::memref::CastOp>();
1241-
if (!memrefCast)
1178+
matchAndRewrite(CastOp op, mlir::PatternRewriter &rewriter) const override {
1179+
auto signCast = op.source().template getDefiningOp<plier::SignCastOp>();
1180+
if (!signCast)
12421181
return mlir::failure();
12431182

1244-
auto srcType = memrefCast.source().getType().cast<mlir::MemRefType>();
1245-
auto dstType = memrefCast.getType().cast<mlir::MemRefType>();
1183+
auto srcType = op.source().getType().template cast<mlir::ShapedType>();
1184+
auto dstType = op.getType().template cast<mlir::ShapedType>();
12461185
if (srcType.getElementType() != dstType.getElementType() ||
12471186
!srcType.hasRank() || !dstType.hasRank())
12481187
return mlir::failure();
12491188

1250-
auto finalType = op.getType().cast<mlir::MemRefType>();
1189+
auto src = signCast.value();
1190+
auto finalType = src.getType().template cast<mlir::ShapedType>();
12511191
auto finalElemType = finalType.getElementType();
12521192

1253-
auto newSrcType = srcType.clone(finalElemType);
12541193
auto newDstType = dstType.clone(finalElemType);
12551194

12561195
auto loc = op.getLoc();
1257-
auto casted = rewriter.createOrFold<plier::SignCastOp>(loc, newSrcType,
1258-
memrefCast.source());
1259-
rewriter.replaceOpWithNewOp<mlir::memref::CastOp>(op, newDstType, casted);
1196+
auto cast = rewriter.createOrFold<CastOp>(loc, newDstType, src);
1197+
rewriter.replaceOpWithNewOp<plier::SignCastOp>(op, dstType, cast);
12601198

12611199
return mlir::success();
12621200
}
@@ -1336,82 +1274,50 @@ struct SignCastTensorCollapseShapePropagate
13361274
}
13371275
};
13381276

1339-
struct SignCastTensorToMemrefPropagate
1340-
: public mlir::OpRewritePattern<plier::SignCastOp> {
1341-
using OpRewritePattern::OpRewritePattern;
1277+
template <typename BuffOp>
1278+
struct SignCastBuferizationPropagate : public mlir::OpRewritePattern<BuffOp> {
1279+
using mlir::OpRewritePattern<BuffOp>::OpRewritePattern;
13421280

13431281
mlir::LogicalResult
1344-
matchAndRewrite(plier::SignCastOp op,
1345-
mlir::PatternRewriter &rewriter) const override {
1346-
auto toMemref = op.value().getDefiningOp<mlir::bufferization::ToMemrefOp>();
1347-
if (!toMemref)
1282+
matchAndRewrite(BuffOp op, mlir::PatternRewriter &rewriter) const override {
1283+
auto signCast =
1284+
op->getOperand(0).template getDefiningOp<plier::SignCastOp>();
1285+
if (!signCast)
13481286
return mlir::failure();
13491287

1350-
auto tensor = toMemref.tensor();
1351-
auto tensorType = tensor.getType().cast<mlir::TensorType>();
1352-
auto dstType = op.getType().cast<mlir::MemRefType>();
1288+
auto src = signCast.value();
1289+
auto srcType = src.getType().template cast<mlir::ShapedType>();
1290+
auto dstType = op.getType().template cast<mlir::ShapedType>();
1291+
auto newDstType = dstType.clone(srcType.getElementType());
13531292

1354-
auto newTensorType = tensorType.clone(dstType.getElementType());
1355-
1356-
auto loc = toMemref->getLoc();
1357-
auto newTensor =
1358-
rewriter.create<plier::SignCastOp>(loc, newTensorType, tensor);
1359-
rewriter.replaceOpWithNewOp<mlir::bufferization::ToMemrefOp>(op, dstType,
1360-
newTensor);
1361-
return mlir::success();
1362-
}
1363-
};
1364-
1365-
struct SignCastMemrefToTensorPropagate
1366-
: public mlir::OpRewritePattern<plier::SignCastOp> {
1367-
using OpRewritePattern::OpRewritePattern;
1368-
1369-
mlir::LogicalResult
1370-
matchAndRewrite(plier::SignCastOp op,
1371-
mlir::PatternRewriter &rewriter) const override {
1372-
auto toTensor = op.value().getDefiningOp<mlir::bufferization::ToTensorOp>();
1373-
if (!toTensor)
1374-
return mlir::failure();
1375-
1376-
auto memref = toTensor.memref();
1377-
auto memrefType = memref.getType().cast<mlir::MemRefType>();
1378-
auto dstType = op.getType().cast<mlir::TensorType>();
1379-
1380-
auto newMemrefType = memrefType.clone(dstType.getElementType());
1381-
1382-
auto loc = toTensor->getLoc();
1383-
auto newMemref =
1384-
rewriter.create<plier::SignCastOp>(loc, newMemrefType, memref);
1385-
rewriter.replaceOpWithNewOp<mlir::bufferization::ToTensorOp>(op, dstType,
1386-
newMemref);
1293+
auto loc = op->getLoc();
1294+
auto res = rewriter.create<BuffOp>(loc, newDstType, src);
1295+
rewriter.replaceOpWithNewOp<plier::SignCastOp>(op, dstType, res);
13871296
return mlir::success();
13881297
}
13891298
};
13901299

1391-
struct SignCastMemrefSubviewPropagate
1392-
: public mlir::OpRewritePattern<plier::SignCastOp> {
1393-
using OpRewritePattern::OpRewritePattern;
1300+
template <typename ViewOp, typename ArrType>
1301+
struct SignCastSubviewPropagate : public mlir::OpRewritePattern<ViewOp> {
1302+
using mlir::OpRewritePattern<ViewOp>::OpRewritePattern;
13941303

13951304
mlir::LogicalResult
1396-
matchAndRewrite(plier::SignCastOp op,
1397-
mlir::PatternRewriter &rewriter) const override {
1398-
auto prevOp = op.value().getDefiningOp<mlir::memref::SubViewOp>();
1399-
if (!prevOp)
1305+
matchAndRewrite(ViewOp op, mlir::PatternRewriter &rewriter) const override {
1306+
auto signCast = op.source().template getDefiningOp<plier::SignCastOp>();
1307+
if (!signCast)
14001308
return mlir::failure();
14011309

1402-
auto src = prevOp.source();
1403-
auto srcType = src.getType().cast<mlir::ShapedType>();
1404-
auto dstType = op.getType().cast<mlir::ShapedType>();
1405-
1406-
auto newSrcType = srcType.clone(dstType.getElementType());
1310+
auto src = signCast.value();
1311+
auto srcType = src.getType().template cast<ArrType>();
1312+
auto dstType = op.getType().template cast<ArrType>();
14071313
auto newDstType =
1408-
dstType.clone(dstType.getElementType()).cast<mlir::MemRefType>();
1314+
dstType.clone(srcType.getElementType()).template cast<ArrType>();
14091315

1410-
auto loc = prevOp->getLoc();
1411-
auto newSrc = rewriter.create<plier::SignCastOp>(loc, newSrcType, src);
1412-
rewriter.replaceOpWithNewOp<mlir::memref::SubViewOp>(
1413-
op, newDstType, newSrc, prevOp.getMixedOffsets(),
1414-
prevOp.getMixedSizes(), prevOp.getMixedStrides());
1316+
auto loc = op->getLoc();
1317+
auto res =
1318+
rewriter.create<ViewOp>(loc, newDstType, src, op.getMixedOffsets(),
1319+
op.getMixedSizes(), op.getMixedStrides());
1320+
rewriter.replaceOpWithNewOp<plier::SignCastOp>(op, dstType, res);
14151321
return mlir::success();
14161322
}
14171323
};
@@ -1427,19 +1333,20 @@ struct SignCastForPropagate : public mlir::OpRewritePattern<mlir::scf::ForOp> {
14271333
auto termResults = term.getResults();
14281334
auto initArgs = op.getInitArgs();
14291335
auto count = static_cast<unsigned>(initArgs.size());
1430-
14311336
assert(termResults.size() == count);
1337+
1338+
auto loc = op->getLoc();
14321339
llvm::SmallVector<mlir::Value> newInitArgs(count);
14331340
bool needUpdate = false;
14341341
for (auto i : llvm::seq(0u, count)) {
14351342
auto initArg = initArgs[i];
14361343
auto yieldArg = termResults[i];
14371344
assert(initArg.getType() == yieldArg.getType());
1438-
auto initCast = initArg.getDefiningOp<plier::SignCastOp>();
14391345
auto yieldCast = yieldArg.getDefiningOp<plier::SignCastOp>();
1440-
if (initCast && yieldCast &&
1441-
initCast.value().getType() == yieldCast.value().getType()) {
1442-
newInitArgs[i] = initCast.value();
1346+
if (yieldCast) {
1347+
auto newType = yieldCast.value().getType();
1348+
newInitArgs[i] =
1349+
rewriter.create<plier::SignCastOp>(loc, newType, initArg);
14431350
needUpdate = true;
14441351
} else {
14451352
newInitArgs[i] = initArg;
@@ -1476,14 +1383,14 @@ struct SignCastForPropagate : public mlir::OpRewritePattern<mlir::scf::ForOp> {
14761383
auto val = mapping.lookupOrDefault(termResults[i]);
14771384
auto newType = newInitArgs[i].getType();
14781385
if (val.getType() != newType)
1479-
val = builder.create<plier::SignCastOp>(loc, newType, val);
1386+
val = val.getDefiningOp<plier::SignCastOp>().value();
14801387

1388+
assert(val.getType() == newType);
14811389
newYieldArgs[i] = val;
14821390
}
14831391
builder.create<mlir::scf::YieldOp>(loc, newYieldArgs);
14841392
};
14851393

1486-
auto loc = op->getLoc();
14871394
auto newResults = rewriter
14881395
.create<mlir::scf::ForOp>(
14891396
loc, op.getLowerBound(), op.getUpperBound(),
@@ -1512,12 +1419,18 @@ void SignCastOp::getCanonicalizationPatterns(::mlir::RewritePatternSet &results,
15121419
results.insert<
15131420
SignCastDimPropagate<mlir::tensor::DimOp>,
15141421
SignCastDimPropagate<mlir::memref::DimOp>, SignCastUndefPropagate,
1515-
SignCastTensorCastPropagate, SignCastMemrefCastPropagate,
1422+
SignCastCastPropagate<mlir::tensor::CastOp>,
1423+
SignCastCastPropagate<mlir::memref::CastOp>,
1424+
SignCastCastPropagate<plier::ChangeLayoutOp>,
15161425
SignCastAllocPropagate<mlir::memref::AllocOp>,
15171426
SignCastAllocPropagate<mlir::memref::AllocaOp>,
15181427
SignCastTensorFromElementsPropagate, SignCastTensorCollapseShapePropagate,
1519-
SignCastTensorToMemrefPropagate, SignCastMemrefToTensorPropagate,
1520-
SignCastMemrefSubviewPropagate, SignCastForPropagate>(context);
1428+
SignCastBuferizationPropagate<mlir::bufferization::ToMemrefOp>,
1429+
SignCastBuferizationPropagate<mlir::bufferization::ToTensorOp>,
1430+
SignCastSubviewPropagate<mlir::tensor::ExtractSliceOp,
1431+
mlir::RankedTensorType>,
1432+
SignCastSubviewPropagate<mlir::memref::SubViewOp, mlir::MemRefType>,
1433+
SignCastForPropagate>(context);
15211434
}
15221435

15231436
void ExtractMemrefMetadataOp::build(::mlir::OpBuilder &odsBuilder,

numba_dpcomp/numba_dpcomp/mlir/tests/test_numba_parfor.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@ def _gen_tests():
5252
]
5353

5454
xfail_tests = {
55-
"test_prange09",
5655
"test_prange03sub",
5756
"test_prange03div",
5857
"test_prange07",

0 commit comments

Comments
 (0)