Skip to content

Commit ef18139

Browse files
authored
[python] Tuple concat support (#193)
1 parent 6fc4065 commit ef18139

File tree

3 files changed

+81
-4
lines changed

3 files changed

+81
-4
lines changed

numba_dpcomp/numba_dpcomp/mlir/tests/test_basic.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -432,6 +432,24 @@ def py_func(a, b, c):
432432
assert_equal(py_func(a, b, c), jit_func(a, b, c))
433433

434434

435+
@parametrize_function_variants(
436+
"py_func",
437+
[
438+
"lambda a, b, c: (a, b) + (c,)",
439+
"lambda a, b, c: (a,) + (b, c)",
440+
"lambda a, b, c: (a, b, c) + ()",
441+
"lambda a, b, c: () + (a, b, c)",
442+
],
443+
)
444+
@pytest.mark.parametrize(
445+
"a, b, c",
446+
itertools.product(_tuple_test_values, _tuple_test_values, _tuple_test_values),
447+
)
448+
def test_tuple_concat(py_func, a, b, c):
449+
jit_func = njit(py_func)
450+
assert_equal(py_func(a, b, c), jit_func(a, b, c))
451+
452+
435453
@pytest.mark.parametrize(
436454
"a, b, c",
437455
itertools.product(_tuple_test_values, _tuple_test_values, _tuple_test_values),

numba_dpcomp/numba_dpcomp/mlir/tests/test_numba_parfor.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,6 @@ def _gen_tests():
9494
"test_linspace", # np.linspace
9595
"test_std", # array.std
9696
"test_mvdot", # np.dot unsupported args
97-
"test_array_tuple_concat", # tuple concat
9897
"test_namedtuple1", # namedtuple support
9998
"test_0d_broadcast", # np.array
10099
"test_var", # array.var
@@ -109,7 +108,7 @@ def _gen_tests():
109108
"test_tuple_concat_with_reverse_slice", # enumerate
110109
"test_reduce", # functools.reduce
111110
"test_two_d_array_reduction", # 'memref<?x?xi64>' and result type 'memref<?x?xi32>' are cast incompatible
112-
"test_tuple_concat", # tuple concat
111+
"test_tuple_concat", # enumerate
113112
"test_two_d_array_reduction_with_float_sizes", # np.array
114113
"test_two_d_array_reduction_reuse", # 'memref<?x?xi64>' and result type 'memref<?x?xi32>' are cast incompatible
115114
"test_parfor_array_access_lower_slice", # plier.getitem

numba_dpcomp/numba_dpcomp/mlir_compiler/lib/pipelines/plier_to_std.cpp

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -953,6 +953,53 @@ struct BinOpLowering : public mlir::OpConversionPattern<plier::BinOp> {
953953
}
954954
};
955955

956+
struct BinOpTupleLowering : public mlir::OpConversionPattern<plier::BinOp> {
957+
using OpConversionPattern::OpConversionPattern;
958+
959+
mlir::LogicalResult
960+
matchAndRewrite(plier::BinOp op, plier::BinOp::Adaptor adaptor,
961+
mlir::ConversionPatternRewriter &rewriter) const override {
962+
auto lhs = adaptor.lhs();
963+
auto rhs = adaptor.rhs();
964+
auto lhsType = lhs.getType().dyn_cast<mlir::TupleType>();
965+
if (!lhsType)
966+
return mlir::failure();
967+
968+
auto loc = op->getLoc();
969+
if (adaptor.op() == "+") {
970+
auto rhsType = rhs.getType().dyn_cast<mlir::TupleType>();
971+
if (!rhsType)
972+
return mlir::failure();
973+
974+
auto count = lhsType.size() + rhsType.size();
975+
llvm::SmallVector<mlir::Value> newArgs;
976+
llvm::SmallVector<mlir::Type> newTypes;
977+
newArgs.reserve(count);
978+
newTypes.reserve(count);
979+
980+
for (auto &arg : {lhs, rhs}) {
981+
auto type = arg.getType().cast<mlir::TupleType>();
982+
for (auto i : llvm::seq<size_t>(0, type.size())) {
983+
auto elemType = type.getType(i);
984+
auto ind = rewriter.create<mlir::arith::ConstantIndexOp>(
985+
loc, static_cast<int64_t>(i));
986+
auto elem =
987+
rewriter.create<plier::GetItemOp>(loc, elemType, arg, ind);
988+
newArgs.emplace_back(elem);
989+
newTypes.emplace_back(elemType);
990+
}
991+
}
992+
993+
auto newTupleType = mlir::TupleType::get(getContext(), newTypes);
994+
rewriter.replaceOpWithNewOp<plier::BuildTupleOp>(op, newTupleType,
995+
newArgs);
996+
return mlir::success();
997+
}
998+
999+
return mlir::failure();
1000+
}
1001+
};
1002+
9561003
static mlir::Value negate(mlir::PatternRewriter &rewriter, mlir::Location loc,
9571004
mlir::Value val, mlir::Type resType) {
9581005
val = doCast(rewriter, loc, val, resType);
@@ -1322,9 +1369,21 @@ void PlierToStdPass::runOnOperation() {
13221369
plier::LiteralType>();
13231370
};
13241371

1372+
auto isTuple = [&](mlir::Type t) -> bool {
1373+
if (!t)
1374+
return false;
1375+
1376+
auto res = typeConverter.convertType(t);
1377+
return res && res.isa<mlir::TupleType>();
1378+
};
1379+
13251380
target.addDynamicallyLegalOp<plier::BinOp>([&](plier::BinOp op) {
1326-
return !isNum(op.lhs().getType()) || !isNum(op.rhs().getType()) ||
1327-
!isNum(op.getType());
1381+
auto lhsType = op.lhs().getType();
1382+
auto rhsType = op.rhs().getType();
1383+
if (op.op() == "+" && isTuple(lhsType) && isTuple(rhsType))
1384+
return false;
1385+
1386+
return !isNum(lhsType) || !isNum(rhsType) || !isNum(op.getType());
13281387
});
13291388
target.addDynamicallyLegalOp<plier::UnaryOp>([&](plier::UnaryOp op) {
13301389
return !isNum(op.value().getType()) && !isNum(op.getType());
@@ -1364,6 +1423,7 @@ void PlierToStdPass::runOnOperation() {
13641423
patterns.insert<
13651424
// clang-format off
13661425
BinOpLowering,
1426+
BinOpTupleLowering,
13671427
UnaryOpLowering,
13681428
LowerCasts,
13691429
ConstOpLowering,

0 commit comments

Comments
 (0)