@@ -953,6 +953,53 @@ struct BinOpLowering : public mlir::OpConversionPattern<plier::BinOp> {
953
953
}
954
954
};
955
955
956
+ struct BinOpTupleLowering : public mlir ::OpConversionPattern<plier::BinOp> {
957
+ using OpConversionPattern::OpConversionPattern;
958
+
959
+ mlir::LogicalResult
960
+ matchAndRewrite (plier::BinOp op, plier::BinOp::Adaptor adaptor,
961
+ mlir::ConversionPatternRewriter &rewriter) const override {
962
+ auto lhs = adaptor.lhs ();
963
+ auto rhs = adaptor.rhs ();
964
+ auto lhsType = lhs.getType ().dyn_cast <mlir::TupleType>();
965
+ if (!lhsType)
966
+ return mlir::failure ();
967
+
968
+ auto loc = op->getLoc ();
969
+ if (adaptor.op () == " +" ) {
970
+ auto rhsType = rhs.getType ().dyn_cast <mlir::TupleType>();
971
+ if (!rhsType)
972
+ return mlir::failure ();
973
+
974
+ auto count = lhsType.size () + rhsType.size ();
975
+ llvm::SmallVector<mlir::Value> newArgs;
976
+ llvm::SmallVector<mlir::Type> newTypes;
977
+ newArgs.reserve (count);
978
+ newTypes.reserve (count);
979
+
980
+ for (auto &arg : {lhs, rhs}) {
981
+ auto type = arg.getType ().cast <mlir::TupleType>();
982
+ for (auto i : llvm::seq<size_t >(0 , type.size ())) {
983
+ auto elemType = type.getType (i);
984
+ auto ind = rewriter.create <mlir::arith::ConstantIndexOp>(
985
+ loc, static_cast <int64_t >(i));
986
+ auto elem =
987
+ rewriter.create <plier::GetItemOp>(loc, elemType, arg, ind);
988
+ newArgs.emplace_back (elem);
989
+ newTypes.emplace_back (elemType);
990
+ }
991
+ }
992
+
993
+ auto newTupleType = mlir::TupleType::get (getContext (), newTypes);
994
+ rewriter.replaceOpWithNewOp <plier::BuildTupleOp>(op, newTupleType,
995
+ newArgs);
996
+ return mlir::success ();
997
+ }
998
+
999
+ return mlir::failure ();
1000
+ }
1001
+ };
1002
+
956
1003
static mlir::Value negate (mlir::PatternRewriter &rewriter, mlir::Location loc,
957
1004
mlir::Value val, mlir::Type resType) {
958
1005
val = doCast (rewriter, loc, val, resType);
@@ -1322,9 +1369,21 @@ void PlierToStdPass::runOnOperation() {
1322
1369
plier::LiteralType>();
1323
1370
};
1324
1371
1372
+ auto isTuple = [&](mlir::Type t) -> bool {
1373
+ if (!t)
1374
+ return false ;
1375
+
1376
+ auto res = typeConverter.convertType (t);
1377
+ return res && res.isa <mlir::TupleType>();
1378
+ };
1379
+
1325
1380
target.addDynamicallyLegalOp <plier::BinOp>([&](plier::BinOp op) {
1326
- return !isNum (op.lhs ().getType ()) || !isNum (op.rhs ().getType ()) ||
1327
- !isNum (op.getType ());
1381
+ auto lhsType = op.lhs ().getType ();
1382
+ auto rhsType = op.rhs ().getType ();
1383
+ if (op.op () == " +" && isTuple (lhsType) && isTuple (rhsType))
1384
+ return false ;
1385
+
1386
+ return !isNum (lhsType) || !isNum (rhsType) || !isNum (op.getType ());
1328
1387
});
1329
1388
target.addDynamicallyLegalOp <plier::UnaryOp>([&](plier::UnaryOp op) {
1330
1389
return !isNum (op.value ().getType ()) && !isNum (op.getType ());
@@ -1364,6 +1423,7 @@ void PlierToStdPass::runOnOperation() {
1364
1423
patterns.insert <
1365
1424
// clang-format off
1366
1425
BinOpLowering,
1426
+ BinOpTupleLowering,
1367
1427
UnaryOpLowering,
1368
1428
LowerCasts,
1369
1429
ConstOpLowering,
0 commit comments