Skip to content

Commit f493b65

Browse files
[mlir:python] Add manual typing annotations to mlir.register_* functions.
This PR adds a manual typing annotations to the `register_operation` and `register_(type|value)_caster` functions in the main `mlir` module. Since those functions return the result `nb::cpp_function`, which is of type `nb::object`, the automatic typing annocations are of the form `def f() -> object`. This isn't particularly precise and leads to type checking errors when the functions are used. Manually defining the annotation with `nb::sig` solves the problem. Signed-off-by: Ingo Müller <[email protected]>
1 parent e66e8aa commit f493b65

File tree

2 files changed

+13
-0
lines changed

2 files changed

+13
-0
lines changed

mlir/include/mlir/Bindings/Python/Nanobind.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include <nanobind/stl/string_view.h>
3131
#include <nanobind/stl/tuple.h>
3232
#include <nanobind/stl/vector.h>
33+
#include <nanobind/typing.h>
3334
#if defined(__clang__) || defined(__GNUC__)
3435
#pragma GCC diagnostic pop
3536
#endif

mlir/lib/Bindings/Python/MainModule.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ using namespace mlir::python;
2424

2525
NB_MODULE(_mlir, m) {
2626
m.doc() = "MLIR Python Native Extension";
27+
m.attr("T") = nb::type_var("T");
28+
m.attr("U") = nb::type_var("U");
2729

2830
nb::class_<PyGlobals>(m, "_Globals")
2931
.def_prop_rw("dialect_search_modules",
@@ -102,6 +104,8 @@ NB_MODULE(_mlir, m) {
102104
return opClass;
103105
});
104106
},
107+
nb::sig("def register_operation(dialect_class: type, *, "
108+
"replace: bool = False) -> typing.Callable[[type[T]], type[T]]"),
105109
"dialect_class"_a, nb::kw_only(), "replace"_a = false,
106110
"Produce a class decorator for registering an Operation class as part of "
107111
"a dialect");
@@ -114,6 +118,10 @@ NB_MODULE(_mlir, m) {
114118
return typeCaster;
115119
});
116120
},
121+
nb::sig("def register_type_caster(typeid: _mlir.ir.TypeID, *, "
122+
"replace: bool = False) "
123+
"-> typing.Callable[[typing.Callable[[T], U]], "
124+
"typing.Callable[[T], U]]"),
117125
"typeid"_a, nb::kw_only(), "replace"_a = false,
118126
"Register a type caster for casting MLIR types to custom user types.");
119127
m.def(
@@ -126,6 +134,10 @@ NB_MODULE(_mlir, m) {
126134
return valueCaster;
127135
});
128136
},
137+
nb::sig("def register_value_caster(typeid: _mlir.ir.TypeID, *, "
138+
"replace: bool = False) "
139+
"-> typing.Callable[[typing.Callable[[T], U]], "
140+
"typing.Callable[[T], U]]"),
129141
"typeid"_a, nb::kw_only(), "replace"_a = false,
130142
"Register a value caster for casting MLIR values to custom user values.");
131143

0 commit comments

Comments
 (0)