@@ -699,33 +699,6 @@ struct ChangeLayoutFromCast
699
699
}
700
700
};
701
701
702
- struct ChangeLayoutSignCast : public mlir ::OpRewritePattern<plier::SignCastOp> {
703
- using OpRewritePattern::OpRewritePattern;
704
-
705
- mlir::LogicalResult
706
- matchAndRewrite (plier::SignCastOp op,
707
- mlir::PatternRewriter &rewriter) const override {
708
- auto cl = op.value ().getDefiningOp <plier::ChangeLayoutOp>();
709
- if (!cl)
710
- return mlir::failure ();
711
-
712
- auto src = cl.source ();
713
- auto srcType = src.getType ().cast <mlir::MemRefType>();
714
- auto oldType = op.getType ().cast <mlir::MemRefType>();
715
- auto newType = mlir::MemRefType::get (
716
- srcType.getShape (), oldType.getElementType (), srcType.getLayout ());
717
-
718
- auto loc = op.getLoc ();
719
- auto newOp = rewriter.createOrFold <plier::SignCastOp>(loc, newType, src);
720
-
721
- if (oldType != newType)
722
- newOp = rewriter.createOrFold <plier::ChangeLayoutOp>(loc, oldType, newOp);
723
-
724
- rewriter.replaceOp (op, newOp);
725
- return mlir::success ();
726
- }
727
- };
728
-
729
702
struct ChangeLayoutLoad : public mlir ::OpRewritePattern<mlir::memref::LoadOp> {
730
703
using OpRewritePattern::OpRewritePattern;
731
704
@@ -1107,11 +1080,10 @@ void ChangeLayoutOp::getCanonicalizationPatterns(
1107
1080
results.insert <
1108
1081
ChangeLayoutIdentity, ChangeLayoutDim, ChangeLayoutExtractMetadata,
1109
1082
ChangeLayoutClone, PropagateCloneType, ChangeLayoutCast,
1110
- ChangeLayoutFromCast, ChangeLayoutSignCast, ChangeLayoutLoad,
1111
- ChangeLayoutStore, ChangeLayoutSubview, ChangeLayoutLinalgGeneric,
1112
- ChangeLayoutLinalgFill, ChangeLayoutIf, ChangeLayout1DReshape,
1113
- ChangeLayoutSliceGetItem, ChangeLayoutCopy, ChangeLayoutExpandShape>(
1114
- context);
1083
+ ChangeLayoutFromCast, ChangeLayoutLoad, ChangeLayoutStore,
1084
+ ChangeLayoutSubview, ChangeLayoutLinalgGeneric, ChangeLayoutLinalgFill,
1085
+ ChangeLayoutIf, ChangeLayout1DReshape, ChangeLayoutSliceGetItem,
1086
+ ChangeLayoutCopy, ChangeLayoutExpandShape>(context);
1115
1087
}
1116
1088
1117
1089
static mlir::Value propagateCasts (mlir::Value val, mlir::Type thisType);
@@ -1198,65 +1170,31 @@ struct SignCastUndefPropagate
1198
1170
}
1199
1171
};
1200
1172
1201
- struct SignCastTensorCastPropagate
1202
- : public mlir::OpRewritePattern<plier::SignCastOp> {
1203
- using OpRewritePattern::OpRewritePattern;
1204
-
1205
- mlir::LogicalResult
1206
- matchAndRewrite (plier::SignCastOp op,
1207
- mlir::PatternRewriter &rewriter) const override {
1208
- auto tensorCast = op.value ().getDefiningOp <mlir::tensor::CastOp>();
1209
- if (!tensorCast)
1210
- return mlir::failure ();
1211
-
1212
- auto srcType = tensorCast.source ().getType ().cast <mlir::TensorType>();
1213
- auto dstType = tensorCast.getType ().cast <mlir::TensorType>();
1214
- if (srcType.getElementType () != dstType.getElementType () ||
1215
- !srcType.hasRank () || !dstType.hasRank ())
1216
- return mlir::failure ();
1217
-
1218
- auto finalType = op.getType ().cast <mlir::TensorType>();
1219
- auto finalElemType = finalType.getElementType ();
1220
-
1221
- auto newSrcType = srcType.clone (finalElemType);
1222
- auto newDstType = dstType.clone (finalElemType);
1223
-
1224
- auto loc = op.getLoc ();
1225
- auto casted = rewriter.createOrFold <plier::SignCastOp>(loc, newSrcType,
1226
- tensorCast.source ());
1227
- rewriter.replaceOpWithNewOp <mlir::tensor::CastOp>(op, newDstType, casted);
1228
-
1229
- return mlir::success ();
1230
- }
1231
- };
1232
-
1233
- struct SignCastMemrefCastPropagate
1234
- : public mlir::OpRewritePattern<plier::SignCastOp> {
1235
- using OpRewritePattern::OpRewritePattern;
1173
+ template <typename CastOp>
1174
+ struct SignCastCastPropagate : public mlir ::OpRewritePattern<CastOp> {
1175
+ using mlir::OpRewritePattern<CastOp>::OpRewritePattern;
1236
1176
1237
1177
mlir::LogicalResult
1238
- matchAndRewrite (plier::SignCastOp op,
1239
- mlir::PatternRewriter &rewriter) const override {
1240
- auto memrefCast = op.value ().getDefiningOp <mlir::memref::CastOp>();
1241
- if (!memrefCast)
1178
+ matchAndRewrite (CastOp op, mlir::PatternRewriter &rewriter) const override {
1179
+ auto signCast = op.source ().template getDefiningOp <plier::SignCastOp>();
1180
+ if (!signCast)
1242
1181
return mlir::failure ();
1243
1182
1244
- auto srcType = memrefCast .source ().getType ().cast <mlir::MemRefType >();
1245
- auto dstType = memrefCast .getType ().cast <mlir::MemRefType >();
1183
+ auto srcType = op .source ().getType ().template cast <mlir::ShapedType >();
1184
+ auto dstType = op .getType ().template cast <mlir::ShapedType >();
1246
1185
if (srcType.getElementType () != dstType.getElementType () ||
1247
1186
!srcType.hasRank () || !dstType.hasRank ())
1248
1187
return mlir::failure ();
1249
1188
1250
- auto finalType = op.getType ().cast <mlir::MemRefType>();
1189
+ auto src = signCast.value ();
1190
+ auto finalType = src.getType ().template cast <mlir::ShapedType>();
1251
1191
auto finalElemType = finalType.getElementType ();
1252
1192
1253
- auto newSrcType = srcType.clone (finalElemType);
1254
1193
auto newDstType = dstType.clone (finalElemType);
1255
1194
1256
1195
auto loc = op.getLoc ();
1257
- auto casted = rewriter.createOrFold <plier::SignCastOp>(loc, newSrcType,
1258
- memrefCast.source ());
1259
- rewriter.replaceOpWithNewOp <mlir::memref::CastOp>(op, newDstType, casted);
1196
+ auto cast = rewriter.createOrFold <CastOp>(loc, newDstType, src);
1197
+ rewriter.replaceOpWithNewOp <plier::SignCastOp>(op, dstType, cast);
1260
1198
1261
1199
return mlir::success ();
1262
1200
}
@@ -1336,82 +1274,50 @@ struct SignCastTensorCollapseShapePropagate
1336
1274
}
1337
1275
};
1338
1276
1339
- struct SignCastTensorToMemrefPropagate
1340
- : public mlir::OpRewritePattern<plier::SignCastOp > {
1341
- using OpRewritePattern::OpRewritePattern;
1277
+ template < typename BuffOp>
1278
+ struct SignCastBuferizationPropagate : public mlir ::OpRewritePattern<BuffOp > {
1279
+ using mlir:: OpRewritePattern<BuffOp> ::OpRewritePattern;
1342
1280
1343
1281
mlir::LogicalResult
1344
- matchAndRewrite (plier::SignCastOp op,
1345
- mlir::PatternRewriter &rewriter) const override {
1346
- auto toMemref = op. value (). getDefiningOp <mlir::bufferization::ToMemrefOp >();
1347
- if (!toMemref )
1282
+ matchAndRewrite (BuffOp op, mlir::PatternRewriter &rewriter) const override {
1283
+ auto signCast =
1284
+ op-> getOperand ( 0 ). template getDefiningOp <plier::SignCastOp >();
1285
+ if (!signCast )
1348
1286
return mlir::failure ();
1349
1287
1350
- auto tensor = toMemref.tensor ();
1351
- auto tensorType = tensor.getType ().cast <mlir::TensorType>();
1352
- auto dstType = op.getType ().cast <mlir::MemRefType>();
1288
+ auto src = signCast.value ();
1289
+ auto srcType = src.getType ().template cast <mlir::ShapedType>();
1290
+ auto dstType = op.getType ().template cast <mlir::ShapedType>();
1291
+ auto newDstType = dstType.clone (srcType.getElementType ());
1353
1292
1354
- auto newTensorType = tensorType.clone (dstType.getElementType ());
1355
-
1356
- auto loc = toMemref->getLoc ();
1357
- auto newTensor =
1358
- rewriter.create <plier::SignCastOp>(loc, newTensorType, tensor);
1359
- rewriter.replaceOpWithNewOp <mlir::bufferization::ToMemrefOp>(op, dstType,
1360
- newTensor);
1361
- return mlir::success ();
1362
- }
1363
- };
1364
-
1365
- struct SignCastMemrefToTensorPropagate
1366
- : public mlir::OpRewritePattern<plier::SignCastOp> {
1367
- using OpRewritePattern::OpRewritePattern;
1368
-
1369
- mlir::LogicalResult
1370
- matchAndRewrite (plier::SignCastOp op,
1371
- mlir::PatternRewriter &rewriter) const override {
1372
- auto toTensor = op.value ().getDefiningOp <mlir::bufferization::ToTensorOp>();
1373
- if (!toTensor)
1374
- return mlir::failure ();
1375
-
1376
- auto memref = toTensor.memref ();
1377
- auto memrefType = memref.getType ().cast <mlir::MemRefType>();
1378
- auto dstType = op.getType ().cast <mlir::TensorType>();
1379
-
1380
- auto newMemrefType = memrefType.clone (dstType.getElementType ());
1381
-
1382
- auto loc = toTensor->getLoc ();
1383
- auto newMemref =
1384
- rewriter.create <plier::SignCastOp>(loc, newMemrefType, memref);
1385
- rewriter.replaceOpWithNewOp <mlir::bufferization::ToTensorOp>(op, dstType,
1386
- newMemref);
1293
+ auto loc = op->getLoc ();
1294
+ auto res = rewriter.create <BuffOp>(loc, newDstType, src);
1295
+ rewriter.replaceOpWithNewOp <plier::SignCastOp>(op, dstType, res);
1387
1296
return mlir::success ();
1388
1297
}
1389
1298
};
1390
1299
1391
- struct SignCastMemrefSubviewPropagate
1392
- : public mlir::OpRewritePattern<plier::SignCastOp > {
1393
- using OpRewritePattern::OpRewritePattern;
1300
+ template < typename ViewOp, typename ArrType>
1301
+ struct SignCastSubviewPropagate : public mlir ::OpRewritePattern<ViewOp > {
1302
+ using mlir:: OpRewritePattern<ViewOp> ::OpRewritePattern;
1394
1303
1395
1304
mlir::LogicalResult
1396
- matchAndRewrite (plier::SignCastOp op,
1397
- mlir::PatternRewriter &rewriter) const override {
1398
- auto prevOp = op.value ().getDefiningOp <mlir::memref::SubViewOp>();
1399
- if (!prevOp)
1305
+ matchAndRewrite (ViewOp op, mlir::PatternRewriter &rewriter) const override {
1306
+ auto signCast = op.source ().template getDefiningOp <plier::SignCastOp>();
1307
+ if (!signCast)
1400
1308
return mlir::failure ();
1401
1309
1402
- auto src = prevOp.source ();
1403
- auto srcType = src.getType ().cast <mlir::ShapedType>();
1404
- auto dstType = op.getType ().cast <mlir::ShapedType>();
1405
-
1406
- auto newSrcType = srcType.clone (dstType.getElementType ());
1310
+ auto src = signCast.value ();
1311
+ auto srcType = src.getType ().template cast <ArrType>();
1312
+ auto dstType = op.getType ().template cast <ArrType>();
1407
1313
auto newDstType =
1408
- dstType.clone (dstType .getElementType ()).cast <mlir::MemRefType >();
1314
+ dstType.clone (srcType .getElementType ()).template cast <ArrType >();
1409
1315
1410
- auto loc = prevOp ->getLoc ();
1411
- auto newSrc = rewriter. create <plier::SignCastOp>(loc, newSrcType, src);
1412
- rewriter.replaceOpWithNewOp <mlir::memref::SubViewOp>(
1413
- op, newDstType, newSrc, prevOp. getMixedOffsets (),
1414
- prevOp. getMixedSizes (), prevOp. getMixedStrides () );
1316
+ auto loc = op ->getLoc ();
1317
+ auto res =
1318
+ rewriter.create <ViewOp>(loc, newDstType, src, op. getMixedOffsets (),
1319
+ op. getMixedSizes (), op. getMixedStrides ());
1320
+ rewriter. replaceOpWithNewOp <plier::SignCastOp>(op, dstType, res );
1415
1321
return mlir::success ();
1416
1322
}
1417
1323
};
@@ -1427,19 +1333,20 @@ struct SignCastForPropagate : public mlir::OpRewritePattern<mlir::scf::ForOp> {
1427
1333
auto termResults = term.getResults ();
1428
1334
auto initArgs = op.getInitArgs ();
1429
1335
auto count = static_cast <unsigned >(initArgs.size ());
1430
-
1431
1336
assert (termResults.size () == count);
1337
+
1338
+ auto loc = op->getLoc ();
1432
1339
llvm::SmallVector<mlir::Value> newInitArgs (count);
1433
1340
bool needUpdate = false ;
1434
1341
for (auto i : llvm::seq (0u , count)) {
1435
1342
auto initArg = initArgs[i];
1436
1343
auto yieldArg = termResults[i];
1437
1344
assert (initArg.getType () == yieldArg.getType ());
1438
- auto initCast = initArg.getDefiningOp <plier::SignCastOp>();
1439
1345
auto yieldCast = yieldArg.getDefiningOp <plier::SignCastOp>();
1440
- if (initCast && yieldCast &&
1441
- initCast.value ().getType () == yieldCast.value ().getType ()) {
1442
- newInitArgs[i] = initCast.value ();
1346
+ if (yieldCast) {
1347
+ auto newType = yieldCast.value ().getType ();
1348
+ newInitArgs[i] =
1349
+ rewriter.create <plier::SignCastOp>(loc, newType, initArg);
1443
1350
needUpdate = true ;
1444
1351
} else {
1445
1352
newInitArgs[i] = initArg;
@@ -1476,14 +1383,14 @@ struct SignCastForPropagate : public mlir::OpRewritePattern<mlir::scf::ForOp> {
1476
1383
auto val = mapping.lookupOrDefault (termResults[i]);
1477
1384
auto newType = newInitArgs[i].getType ();
1478
1385
if (val.getType () != newType)
1479
- val = builder. create <plier::SignCastOp>(loc, newType, val );
1386
+ val = val. getDefiningOp <plier::SignCastOp>(). value ( );
1480
1387
1388
+ assert (val.getType () == newType);
1481
1389
newYieldArgs[i] = val;
1482
1390
}
1483
1391
builder.create <mlir::scf::YieldOp>(loc, newYieldArgs);
1484
1392
};
1485
1393
1486
- auto loc = op->getLoc ();
1487
1394
auto newResults = rewriter
1488
1395
.create <mlir::scf::ForOp>(
1489
1396
loc, op.getLowerBound (), op.getUpperBound (),
@@ -1512,12 +1419,18 @@ void SignCastOp::getCanonicalizationPatterns(::mlir::RewritePatternSet &results,
1512
1419
results.insert <
1513
1420
SignCastDimPropagate<mlir::tensor::DimOp>,
1514
1421
SignCastDimPropagate<mlir::memref::DimOp>, SignCastUndefPropagate,
1515
- SignCastTensorCastPropagate, SignCastMemrefCastPropagate,
1422
+ SignCastCastPropagate<mlir::tensor::CastOp>,
1423
+ SignCastCastPropagate<mlir::memref::CastOp>,
1424
+ SignCastCastPropagate<plier::ChangeLayoutOp>,
1516
1425
SignCastAllocPropagate<mlir::memref::AllocOp>,
1517
1426
SignCastAllocPropagate<mlir::memref::AllocaOp>,
1518
1427
SignCastTensorFromElementsPropagate, SignCastTensorCollapseShapePropagate,
1519
- SignCastTensorToMemrefPropagate, SignCastMemrefToTensorPropagate,
1520
- SignCastMemrefSubviewPropagate, SignCastForPropagate>(context);
1428
+ SignCastBuferizationPropagate<mlir::bufferization::ToMemrefOp>,
1429
+ SignCastBuferizationPropagate<mlir::bufferization::ToTensorOp>,
1430
+ SignCastSubviewPropagate<mlir::tensor::ExtractSliceOp,
1431
+ mlir::RankedTensorType>,
1432
+ SignCastSubviewPropagate<mlir::memref::SubViewOp, mlir::MemRefType>,
1433
+ SignCastForPropagate>(context);
1521
1434
}
1522
1435
1523
1436
void ExtractMemrefMetadataOp::build (::mlir::OpBuilder &odsBuilder,
0 commit comments