@@ -529,30 +529,6 @@ struct ChangeLayoutSignCast : public mlir::OpRewritePattern<plier::SignCastOp> {
529
529
}
530
530
};
531
531
532
- struct ChangeLayoutReduceRank
533
- : public mlir::OpRewritePattern<plier::ReduceRankOp> {
534
- using OpRewritePattern::OpRewritePattern;
535
-
536
- mlir::LogicalResult
537
- matchAndRewrite (plier::ReduceRankOp op,
538
- mlir::PatternRewriter &rewriter) const override {
539
- auto cl = op.source ().getDefiningOp <plier::ChangeLayoutOp>();
540
- if (!cl)
541
- return mlir::failure ();
542
-
543
- auto loc = op.getLoc ();
544
- auto newOp = rewriter.createOrFold <plier::ReduceRankOp>(loc, cl.source (),
545
- op.getMapping ());
546
- auto oldType = op.getType ();
547
- auto newType = newOp.getType ();
548
- if (oldType != newType)
549
- newOp = rewriter.createOrFold <plier::ChangeLayoutOp>(loc, oldType, newOp);
550
-
551
- rewriter.replaceOp (op, newOp);
552
- return mlir::success ();
553
- }
554
- };
555
-
556
532
struct ChangeLayoutLoad : public mlir ::OpRewritePattern<mlir::memref::LoadOp> {
557
533
using OpRewritePattern::OpRewritePattern;
558
534
@@ -829,11 +805,19 @@ struct ChangeLayout1DReshape
829
805
ArrayType strides (srcRank, rewriter.getIndexAttr (1 ));
830
806
auto view = rewriter.createOrFold <mlir::memref::SubViewOp>(
831
807
loc, source, offsets, sizes, strides);
832
- auto resType = view.getType ().cast <mlir::MemRefType>();
833
- if (resType.getRank () > dstType.getRank ()) {
834
- // TODO: Rank-reducing subview
835
- const int32_t mapping[1 ] = {static_cast <int32_t >(*srcDimIndex)};
836
- view = rewriter.createOrFold <plier::ReduceRankOp>(loc, view, mapping);
808
+ auto dstRank = dstType.getRank ();
809
+ if (srcRank != dstRank) {
810
+ assert (dstRank < srcRank);
811
+ llvm::SmallVector<mlir::OpFoldResult> newOfsets (srcRank,
812
+ rewriter.getIndexAttr (0 ));
813
+ llvm::SmallVector<mlir::OpFoldResult> newStrides (
814
+ srcRank, rewriter.getIndexAttr (1 ));
815
+ auto viewType = view.getType ().cast <mlir::MemRefType>();
816
+ auto reducedType = mlir::memref::SubViewOp::inferRankReducedResultType (
817
+ dstRank, viewType, newOfsets, sizes, newStrides)
818
+ .cast <mlir::MemRefType>();
819
+ view = rewriter.create <mlir::memref::SubViewOp>(
820
+ loc, reducedType, view, newOfsets, sizes, newStrides);
837
821
}
838
822
rewriter.replaceOpWithNewOp <plier::ChangeLayoutOp>(op, dstType, view);
839
823
return mlir::success ();
@@ -923,14 +907,13 @@ struct ChangeLayoutExpandShape
923
907
924
908
void ChangeLayoutOp::getCanonicalizationPatterns (
925
909
::mlir::RewritePatternSet &results, ::mlir::MLIRContext *context) {
926
- results
927
- .insert <ChangeLayoutIdentity, ChangeLayoutReduceRank, ChangeLayoutDim,
928
- ChangeLayoutExtractMetadata, ChangeLayoutClone,
929
- PropagateCloneType, ChangeLayoutCast, ChangeLayoutSignCast,
930
- ChangeLayoutLoad, ChangeLayoutStore, ChangeLayoutSubview,
931
- ChangeLayoutLinalgGeneric, ChangeLayoutLinalgFill, ChangeLayoutIf,
932
- ChangeLayout1DReshape, ChangeLayoutSliceGetItem, ChangeLayoutCopy,
933
- ChangeLayoutExpandShape>(context);
910
+ results.insert <
911
+ ChangeLayoutIdentity, ChangeLayoutDim, ChangeLayoutExtractMetadata,
912
+ ChangeLayoutClone, PropagateCloneType, ChangeLayoutCast,
913
+ ChangeLayoutSignCast, ChangeLayoutLoad, ChangeLayoutStore,
914
+ ChangeLayoutSubview, ChangeLayoutLinalgGeneric, ChangeLayoutLinalgFill,
915
+ ChangeLayoutIf, ChangeLayout1DReshape, ChangeLayoutSliceGetItem,
916
+ ChangeLayoutCopy, ChangeLayoutExpandShape>(context);
934
917
}
935
918
936
919
static mlir::Value propagateCasts (mlir::Value val, mlir::Type thisType);
@@ -1207,31 +1190,6 @@ struct SignCastMemrefToTensorPropagate
1207
1190
}
1208
1191
};
1209
1192
1210
- struct SignCastReduceRankPropagate
1211
- : public mlir::OpRewritePattern<plier::SignCastOp> {
1212
- using OpRewritePattern::OpRewritePattern;
1213
-
1214
- mlir::LogicalResult
1215
- matchAndRewrite (plier::SignCastOp op,
1216
- mlir::PatternRewriter &rewriter) const override {
1217
- auto prevOp = op.value ().getDefiningOp <plier::ReduceRankOp>();
1218
- if (!prevOp)
1219
- return mlir::failure ();
1220
-
1221
- auto src = prevOp.source ();
1222
- auto srcType = src.getType ().cast <mlir::ShapedType>();
1223
- auto dstType = op.getType ().cast <mlir::ShapedType>();
1224
-
1225
- auto newSrcType = srcType.clone (dstType.getElementType ());
1226
-
1227
- auto loc = prevOp->getLoc ();
1228
- auto newSrc = rewriter.create <plier::SignCastOp>(loc, newSrcType, src);
1229
- rewriter.replaceOpWithNewOp <plier::ReduceRankOp>(op, newSrc,
1230
- prevOp.getMapping ());
1231
- return mlir::success ();
1232
- }
1233
- };
1234
-
1235
1193
struct SignCastMemrefSubviewPropagate
1236
1194
: public mlir::OpRewritePattern<plier::SignCastOp> {
1237
1195
using OpRewritePattern::OpRewritePattern;
@@ -1272,175 +1230,7 @@ void SignCastOp::getCanonicalizationPatterns(::mlir::RewritePatternSet &results,
1272
1230
SignCastAllocPropagate<mlir::memref::AllocaOp>,
1273
1231
SignCastTensorFromElementsPropagate, SignCastTensorCollapseShapePropagate,
1274
1232
SignCastTensorToMemrefPropagate, SignCastMemrefToTensorPropagate,
1275
- SignCastReduceRankPropagate, SignCastMemrefSubviewPropagate>(context);
1276
- }
1277
-
1278
- void ReduceRankOp::build (::mlir::OpBuilder &odsBuilder,
1279
- ::mlir::OperationState &odsState, ::mlir::Value src,
1280
- ::mlir::ArrayRef<int32_t > mapping) {
1281
- assert (src.getType ().isa <mlir::ShapedType>());
1282
- auto srcType = src.getType ().cast <mlir::ShapedType>();
1283
- assert (srcType.hasRank ());
1284
- auto srcRank = static_cast <unsigned >(srcType.getRank ());
1285
- assert (!mapping.empty ());
1286
- assert (llvm::all_of (mapping, [&](int32_t val) {
1287
- return val >= 0 && val < static_cast <int32_t >(srcRank);
1288
- }));
1289
- auto mapAttr = odsBuilder.getI32ArrayAttr (mapping);
1290
- auto srcShape = srcType.getShape ();
1291
- llvm::SmallVector<int64_t > shape (mapping.size ());
1292
- for (auto it : llvm::enumerate (mapping))
1293
- shape[it.index ()] = srcShape[static_cast <size_t >(it.value ())];
1294
-
1295
- if (auto tensorType = srcType.dyn_cast <mlir::RankedTensorType>()) {
1296
- auto retType = mlir::RankedTensorType::get (
1297
- shape, tensorType.getElementType (), tensorType.getEncoding ());
1298
- build (odsBuilder, odsState, retType, src, mapAttr);
1299
- } else if (auto memrefType = srcType.dyn_cast <mlir::MemRefType>()) {
1300
- auto affineMap = [&]() {
1301
- mlir::AffineMap ret;
1302
- if (!memrefType.getLayout ().isIdentity ()) {
1303
- auto affineMap = memrefType.getLayout ().getAffineMap ();
1304
- auto context = odsBuilder.getContext ();
1305
- llvm::SmallVector<mlir::AffineExpr> dimReplacements (srcRank);
1306
- llvm::SmallVector<mlir::AffineExpr> symReplacements (srcRank + 1 );
1307
- symReplacements[0 ] = mlir::getAffineSymbolExpr (0 , context);
1308
- for (auto i : llvm::seq (0u , srcRank)) {
1309
- auto it = llvm::find (mapping, i);
1310
- if (it != mapping.end ()) {
1311
- auto srcIndex = static_cast <unsigned >(it - mapping.begin ());
1312
- dimReplacements[i] = mlir::getAffineDimExpr (srcIndex, context);
1313
- symReplacements[i + 1 ] =
1314
- mlir::getAffineSymbolExpr (srcIndex + 1 , context);
1315
- } else {
1316
- dimReplacements[i] = mlir::getAffineConstantExpr (0 , context);
1317
- symReplacements[i + 1 ] = mlir::getAffineConstantExpr (0 , context);
1318
- }
1319
- }
1320
- auto dstRank = static_cast <unsigned >(mapping.size ());
1321
- auto resMap = affineMap.replaceDimsAndSymbols (
1322
- dimReplacements, symReplacements, dstRank, dstRank + 1 );
1323
- ret = mlir::simplifyAffineMap (resMap);
1324
- }
1325
- return ret;
1326
- }();
1327
-
1328
- auto retType =
1329
- mlir::MemRefType::get (shape, memrefType.getElementType (), affineMap,
1330
- memrefType.getMemorySpace ());
1331
- build (odsBuilder, odsState, retType, src, mapAttr);
1332
- } else {
1333
- llvm_unreachable (" ReduceRankOp: Invalid src type" );
1334
- }
1335
- }
1336
-
1337
- mlir::OpFoldResult
1338
- ReduceRankOp::fold (llvm::ArrayRef<mlir::Attribute> /* operands*/ ) {
1339
- auto src = source ();
1340
- if (src.getType () == getType ()) {
1341
- return src;
1342
- }
1343
- return nullptr ;
1344
- }
1345
-
1346
- llvm::SmallVector<int32_t > ReduceRankOp::getMapping () {
1347
- auto m = mapping ();
1348
- llvm::SmallVector<int32_t > ret (m.size ());
1349
- llvm::transform (m, ret.begin (), [](mlir::Attribute a) {
1350
- return a.cast <mlir::IntegerAttr>().getValue ().getSExtValue ();
1351
- });
1352
- return ret;
1353
- }
1354
-
1355
- namespace {
1356
- template <typename Op>
1357
- struct ReduceRankDimPropagate : public mlir ::OpRewritePattern<Op> {
1358
- using mlir::OpRewritePattern<Op>::OpRewritePattern;
1359
-
1360
- mlir::LogicalResult
1361
- matchAndRewrite (Op op, mlir::PatternRewriter &rewriter) const override {
1362
- auto index = mlir::getConstantIntValue (op.index ());
1363
- if (!index)
1364
- return mlir::failure ();
1365
-
1366
- auto prev = op.source ().template getDefiningOp <plier::ReduceRankOp>();
1367
- if (!prev)
1368
- return mlir::failure ();
1369
-
1370
- auto mappedArg = prev.mapping ()[*index]
1371
- .template cast <mlir::IntegerAttr>()
1372
- .getValue ()
1373
- .getSExtValue ();
1374
- rewriter.replaceOpWithNewOp <Op>(op, prev.source (), mappedArg);
1375
- return mlir::success ();
1376
- }
1377
- };
1378
-
1379
- static auto mapReduceRankIndices (mlir::OpBuilder &builder, mlir::Location loc,
1380
- plier::ReduceRankOp src,
1381
- mlir::ValueRange srcIndices) {
1382
- auto srcMemref = src.getViewSource ();
1383
- auto srcMemrefType = srcMemref.getType ().cast <mlir::MemRefType>();
1384
- auto rank = static_cast <unsigned >(srcMemrefType.getRank ());
1385
- auto zero = builder.createOrFold <mlir::arith::ConstantIndexOp>(loc, 0 );
1386
- auto mapping = src.getMapping ();
1387
- llvm::SmallVector<mlir::Value> indices (rank);
1388
- for (auto i : llvm::seq (0u , rank)) {
1389
- auto it = llvm::find (mapping, static_cast <int32_t >(i));
1390
- if (mapping.end () == it) {
1391
- indices[i] = zero;
1392
- } else {
1393
- auto dstIndex = static_cast <size_t >(it - mapping.begin ());
1394
- indices[i] = srcIndices[dstIndex];
1395
- }
1396
- }
1397
- return indices;
1398
- }
1399
-
1400
- struct ReduceRankLoadPropagate
1401
- : public mlir::OpRewritePattern<mlir::memref::LoadOp> {
1402
- using OpRewritePattern::OpRewritePattern;
1403
-
1404
- mlir::LogicalResult
1405
- matchAndRewrite (mlir::memref::LoadOp op,
1406
- mlir::PatternRewriter &rewriter) const override {
1407
- auto src = op.memref ().getDefiningOp <plier::ReduceRankOp>();
1408
- if (!src)
1409
- return mlir::failure ();
1410
-
1411
- auto indices =
1412
- mapReduceRankIndices (rewriter, op.getLoc (), src, op.indices ());
1413
- rewriter.replaceOpWithNewOp <mlir::memref::LoadOp>(op, src.getViewSource (),
1414
- indices);
1415
- return mlir::success ();
1416
- }
1417
- };
1418
-
1419
- struct ReduceRankStorePropagate
1420
- : public mlir::OpRewritePattern<mlir::memref::StoreOp> {
1421
- using OpRewritePattern::OpRewritePattern;
1422
-
1423
- mlir::LogicalResult
1424
- matchAndRewrite (mlir::memref::StoreOp op,
1425
- mlir::PatternRewriter &rewriter) const override {
1426
- auto src = op.memref ().getDefiningOp <plier::ReduceRankOp>();
1427
- if (!src)
1428
- return mlir::failure ();
1429
-
1430
- auto indices =
1431
- mapReduceRankIndices (rewriter, op.getLoc (), src, op.indices ());
1432
- rewriter.replaceOpWithNewOp <mlir::memref::StoreOp>(
1433
- op, op.value (), src.getViewSource (), indices);
1434
- return mlir::success ();
1435
- }
1436
- };
1437
- } // namespace
1438
-
1439
- void ReduceRankOp::getCanonicalizationPatterns (
1440
- ::mlir::RewritePatternSet &results, ::mlir::MLIRContext *context) {
1441
- results.insert <ReduceRankDimPropagate<mlir::tensor::DimOp>,
1442
- ReduceRankDimPropagate<mlir::memref::DimOp>,
1443
- ReduceRankLoadPropagate, ReduceRankStorePropagate>(context);
1233
+ SignCastMemrefSubviewPropagate>(context);
1444
1234
}
1445
1235
1446
1236
void ExtractMemrefMetadataOp::build (::mlir::OpBuilder &odsBuilder,
@@ -1502,26 +1292,6 @@ ExtractMemrefMetadataOp::fold(llvm::ArrayRef<mlir::Attribute> /*operands*/) {
1502
1292
return getResult ();
1503
1293
}
1504
1294
1505
- if (auto reduceRank = src.getDefiningOp <plier::ReduceRankOp>()) {
1506
- auto newSrc = reduceRank.source ();
1507
- if (idx == -1 ) {
1508
- sourceMutable ().assign (newSrc);
1509
- return getResult ();
1510
- }
1511
-
1512
- auto mapping = reduceRank.getMapping ();
1513
- if (static_cast <unsigned >(idx) < mapping.size ()) {
1514
- auto newIdx = mapping[static_cast <unsigned >(idx)];
1515
- assert (newIdx >= 0 );
1516
- sourceMutable ().assign (newSrc);
1517
- auto type = dimIndexAttr ().getType ();
1518
- dimIndexAttr (mlir::IntegerAttr::get (type, newIdx));
1519
- return getResult ();
1520
- }
1521
-
1522
- return nullptr ;
1523
- }
1524
-
1525
1295
if (auto cast = src.getDefiningOp <mlir::memref::CastOp>()) {
1526
1296
auto castSrc = cast.source ();
1527
1297
auto castSrcType = castSrc.getType ().cast <mlir::ShapedType>();
0 commit comments