Skip to content

Commit 6fc4065

Browse files
authored
[python] Fix literal return issue (#192)
1 parent 806a0d3 commit 6fc4065

File tree

4 files changed

+75
-11
lines changed

4 files changed

+75
-11
lines changed

numba_dpcomp/numba_dpcomp/mlir/tests/test_basic.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,15 @@ def py_func(a):
6767
assert_equal(py_func(val), jit_func(val))
6868

6969

70+
@pytest.mark.parametrize("val", _test_values)
71+
def test_ret_literal(val):
72+
def py_func():
73+
return val
74+
75+
jit_func = njit(py_func)
76+
assert_equal(py_func(), jit_func())
77+
78+
7079
@parametrize_function_variants(
7180
"py_func",
7281
[

numba_dpcomp/numba_dpcomp/mlir/tests/test_numba_parfor.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,6 @@ def _gen_tests():
7575
"test_unsigned_refusal_to_vectorize", # Need to hook asm checks
7676
"test_vectorizer_fastmath_asm", # Need to hook asm checks
7777
"test_kde_example", # List suport
78-
"test_prange27", # Literal return issue
7978
"test_simple01", # Empty shape not failed
8079
"test_kmeans", # List suport
8180
"test_ndarray_fill", # array.fill
@@ -140,7 +139,6 @@ def _gen_tests():
140139
"test_issue5942_2", # invalid result
141140
"test_parfor_ufunc_typing", # np.isinf
142141
"test_issue_5098", # list support and more
143-
"test_parfor_slice27", # Literal return issue
144142
"test_ufunc_expr", # np.bitwise_and(
145143
"test_parfor_generate_fuse", # operand #0 does not dominate this use
146144
"test_parfor_slice7", # array.transpose

numba_dpcomp/numba_dpcomp/mlir_compiler/lib/lowering.cpp

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -272,17 +272,21 @@ struct PlierLowerer final {
272272
return builder.create<plier::ArgOp>(
273273
getCurrentLoc(), index, target.attr("name").cast<std::string>());
274274
}
275-
if (py::isinstance(value, insts.Expr)) {
275+
276+
if (py::isinstance(value, insts.Expr))
276277
return lowerExpr(value);
277-
}
278-
if (py::isinstance(value, insts.Var)) {
278+
279+
if (py::isinstance(value, insts.Var))
279280
return loadvar(value);
280-
}
281-
if (py::isinstance(value, insts.Const)) {
281+
282+
if (py::isinstance(value, insts.Const))
282283
return getConst(value.attr("value"));
283-
}
284+
284285
if (py::isinstance(value, insts.Global) ||
285286
py::isinstance(value, insts.FreeVar)) {
287+
auto constVal = getConstOrNull(value.attr("value"));
288+
if (constVal)
289+
return constVal;
286290
auto name = value.attr("name").cast<std::string>();
287291
return builder.create<plier::GlobalOp>(getCurrentLoc(), name);
288292
}
@@ -542,7 +546,7 @@ struct PlierLowerer final {
542546
builder.create<mlir::cf::BranchOp>(getCurrentLoc(), mlir::None, block);
543547
}
544548

545-
mlir::Value getConst(py::handle val) {
549+
mlir::Value getConstOrNull(py::handle val) {
546550
auto getVal = [&](mlir::Attribute attr) {
547551
return builder.create<plier::ConstOp>(getCurrentLoc(), attr);
548552
};
@@ -558,8 +562,15 @@ struct PlierLowerer final {
558562
if (py::isinstance<py::none>(val))
559563
return getVal(builder.getUnitAttr());
560564

561-
plier::reportError(llvm::Twine("get_const unhandled type \"") +
562-
py::str(val.get_type()).cast<std::string>() + "\"");
565+
return {};
566+
}
567+
568+
mlir::Value getConst(py::handle val) {
569+
auto ret = getConstOrNull(val);
570+
if (!ret)
571+
plier::reportError(llvm::Twine("get_const unhandled type \"") +
572+
py::str(val.get_type()).cast<std::string>() + "\"");
573+
return ret;
563574
}
564575

565576
mlir::FunctionType getFuncType(py::handle fnargs, py::handle restype) {

numba_dpcomp/numba_dpcomp/mlir_compiler/lib/pipelines/plier_to_std.cpp

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1385,13 +1385,59 @@ void PlierToStdPass::runOnOperation() {
13851385
signalPassFailure();
13861386
}
13871387

1388+
struct ConvertLiteralTypesPass
1389+
: public mlir::PassWrapper<PlierToStdPass,
1390+
mlir::OperationPass<mlir::ModuleOp>> {
1391+
virtual void
1392+
getDependentDialects(mlir::DialectRegistry &registry) const override {
1393+
registry.insert<mlir::func::FuncDialect>();
1394+
}
1395+
1396+
void runOnOperation() override {
1397+
mlir::TypeConverter typeConverter;
1398+
// Convert unknown types to itself
1399+
typeConverter.addConversion([](mlir::Type type) { return type; });
1400+
1401+
auto context = &getContext();
1402+
typeConverter.addConversion(
1403+
[](plier::LiteralType type) { return type.getValue().getType(); });
1404+
1405+
auto materializeCast =
1406+
[](mlir::OpBuilder &builder, mlir::Type type, mlir::ValueRange inputs,
1407+
mlir::Location loc) -> llvm::Optional<mlir::Value> {
1408+
if (inputs.size() == 1)
1409+
return builder
1410+
.create<mlir::UnrealizedConversionCastOp>(loc, type, inputs.front())
1411+
.getResult(0);
1412+
1413+
return llvm::None;
1414+
};
1415+
typeConverter.addArgumentMaterialization(materializeCast);
1416+
typeConverter.addSourceMaterialization(materializeCast);
1417+
typeConverter.addTargetMaterialization(materializeCast);
1418+
1419+
mlir::RewritePatternSet patterns(context);
1420+
mlir::ConversionTarget target(*context);
1421+
1422+
plier::populateControlFlowTypeConversionRewritesAndTarget(typeConverter,
1423+
patterns, target);
1424+
plier::populateTupleTypeConversionRewritesAndTarget(typeConverter, patterns,
1425+
target);
1426+
1427+
if (mlir::failed(mlir::applyPartialConversion(getOperation(), target,
1428+
std::move(patterns))))
1429+
signalPassFailure();
1430+
}
1431+
};
1432+
13881433
static void populatePlierToStdPipeline(mlir::OpPassManager &pm) {
13891434
pm.addPass(mlir::createCanonicalizerPass());
13901435
pm.addPass(std::make_unique<PlierToStdPass>());
13911436
pm.addPass(mlir::createCanonicalizerPass());
13921437
pm.addPass(std::make_unique<BuiltinCallsLoweringPass>());
13931438
pm.addPass(plier::createForceInlinePass());
13941439
pm.addPass(mlir::createSymbolDCEPass());
1440+
pm.addPass(std::make_unique<ConvertLiteralTypesPass>());
13951441
pm.addPass(mlir::createCanonicalizerPass());
13961442
}
13971443
} // namespace

0 commit comments

Comments
 (0)