From bbcf647c1afcdc6f96b2252acc655d99140c73c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Thu, 7 Nov 2024 12:21:17 +0000 Subject: [PATCH] [mlir][python] Make types in register_(dialect|operation) more narrow. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR makes the `pyClass`/`dialectClass` arguments of the pybind11 functions `register_dialect` and `register_operation` as well as their return types more narrow, concretely, a `py::type` instead of a `py::object`. As the name of the arguments indicate, they have to be called with a type instance (a "class"). The PR also updates the typing stubs of these functions (in the corresponding `.pyi` file), such that static type checkers are aware of the changed type. With the previous typing information, `pyright` raised errors on code generated by tablegen. Signed-off-by: Ingo Müller --- mlir/lib/Bindings/Python/MainModule.cpp | 6 +++--- mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp index 8da1ab16a4514..7c27021902de3 100644 --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -58,7 +58,7 @@ PYBIND11_MODULE(_mlir, m) { // Registration decorators. m.def( "register_dialect", - [](py::object pyClass) { + [](py::type pyClass) { std::string dialectNamespace = pyClass.attr("DIALECT_NAMESPACE").cast(); PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass); @@ -68,9 +68,9 @@ PYBIND11_MODULE(_mlir, m) { "Class decorator for registering a custom Dialect wrapper"); m.def( "register_operation", - [](const py::object &dialectClass, bool replace) -> py::cpp_function { + [](const py::type &dialectClass, bool replace) -> py::cpp_function { return py::cpp_function( - [dialectClass, replace](py::object opClass) -> py::object { + [dialectClass, replace](py::type opClass) -> py::type { std::string operationName = opClass.attr("OPERATION_NAME").cast(); PyGlobals::get().registerOperationImpl(operationName, opClass, diff --git a/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi b/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi index 42694747e5f24..03449b70b7fa3 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi @@ -8,5 +8,5 @@ class _Globals: def append_dialect_search_prefix(self, module_name: str) -> None: ... def _check_dialect_module_loaded(self, dialect_namespace: str) -> bool: ... -def register_dialect(dialect_class: type) -> object: ... -def register_operation(dialect_class: type, *, replace: bool = ...) -> object: ... +def register_dialect(dialect_class: type) -> type: ... +def register_operation(dialect_class: type, *, replace: bool = ...) -> type: ...