Skip to content

Commit 7994333

Browse files
committed
[MLIR][Python] enable type stub auto-generation
1 parent 6166fda commit 7994333

File tree

13 files changed

+57
-3164
lines changed

13 files changed

+57
-3164
lines changed

mlir/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,10 @@ configure_file(
191191

192192
set(MLIR_BINDINGS_PYTHON_NB_DOMAIN "mlir"
193193
CACHE STRING "nanobind domain for MLIR python bindings.")
194+
set(MLIR_PYTHON_PACKAGE_PREFIX "mlir"
195+
CACHE STRING "Specifies that all MLIR packages are co-located under the
196+
`MLIR_PYTHON_PACKAGE_PREFIX` top level package (the API has been
197+
embedded in a relocatable way).")
194198
set(MLIR_ENABLE_BINDINGS_PYTHON 0 CACHE BOOL
195199
"Enables building of Python bindings.")
196200
set(MLIR_BINDINGS_PYTHON_INSTALL_PREFIX "python_packages/mlir_core/mlir" CACHE STRING

mlir/cmake/modules/AddMLIRPython.cmake

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,38 @@ function(declare_mlir_python_sources name)
9999
endif()
100100
endfunction()
101101

102+
function(generate_type_stubs module_name depends_target output_dir)
103+
if(EXISTS ${nanobind_DIR}/../src/stubgen.py)
104+
set(NB_STUBGEN "${nanobind_DIR}/../src/stubgen.py")
105+
elseif(EXISTS ${nanobind_DIR}/../stubgen.py)
106+
set(NB_STUBGEN "${nanobind_DIR}/../stubgen.py")
107+
else()
108+
message(FATAL_ERROR "generate_type_stubs(): could not locate 'stubgen.py'!")
109+
endif()
110+
111+
set(NB_STUBGEN_CMD
112+
"${Python_EXECUTABLE}"
113+
"${NB_STUBGEN}"
114+
--module
115+
"${MLIR_PYTHON_PACKAGE_PREFIX}._mlir_libs.${module_name}"
116+
-i
117+
"${MLIR_BINARY_DIR}/${MLIR_BINDINGS_PYTHON_INSTALL_PREFIX}/.."
118+
--recursive
119+
--include-private
120+
--output-dir
121+
"${output_dir}")
122+
123+
set(NB_STUBGEN_OUTPUT "${output_dir}/${module_name}.pyi")
124+
add_custom_command(
125+
OUTPUT ${NB_STUBGEN_OUTPUT}
126+
COMMAND ${NB_STUBGEN_CMD}
127+
WORKING_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}"
128+
DEPENDS ${depends_target})
129+
set(_name "MLIRPythonModuleStubs_${module_name}")
130+
add_custom_target("${_name}" ALL DEPENDS ${NB_STUBGEN_OUTPUT})
131+
set(NB_STUBGEN_CUSTOM_TARGET "${_name}" PARENT_SCOPE)
132+
endfunction()
133+
102134
# Function: declare_mlir_python_extension
103135
# Declares a buildable python extension from C++ source files. The built
104136
# module is considered a python source file and included as everything else.
@@ -243,6 +275,17 @@ function(add_mlir_python_modules name)
243275
)
244276
add_dependencies(${modules_target} ${_extension_target})
245277
mlir_python_setup_extension_rpath(${_extension_target})
278+
generate_type_stubs(
279+
${_module_name}
280+
${_extension_target}
281+
"${CMAKE_CURRENT_SOURCE_DIR}/mlir/_mlir_libs/_mlir"
282+
)
283+
declare_mlir_python_sources("_${_module_name}_type_stub_gen"
284+
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
285+
ADD_TO_PARENT "${sources_target}"
286+
SOURCES_GLOB "_mlir_libs/${_module_name}/**/*.pyi"
287+
)
288+
add_dependencies("${modules_target}" "${NB_STUBGEN_CUSTOM_TARGET}")
246289
else()
247290
message(SEND_ERROR "Unrecognized source type '${_source_type}' for python source target ${sources_target}")
248291
return()

mlir/examples/standalone/python/CMakeLists.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@ include(AddMLIRPython)
22

33
# Specifies that all MLIR packages are co-located under the `mlir_standalone`
44
# top level package (the API has been embedded in a relocatable way).
5-
# TODO: Add an upstream cmake param for this vs having a global here.
6-
add_compile_definitions("MLIR_PYTHON_PACKAGE_PREFIX=mlir_standalone.")
5+
add_compile_definitions("MLIR_PYTHON_PACKAGE_PREFIX=${MLIR_PYTHON_PACKAGE_PREFIX}.")
76

87

98
################################################################################

mlir/python/CMakeLists.txt

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
include(AddMLIRPython)
22

3+
# Specifies that all MLIR packages are co-located under the `mlir_standalone`
4+
# top level package (the API has been embedded in a relocatable way).
5+
add_compile_definitions("MLIR_PYTHON_PACKAGE_PREFIX=${MLIR_PYTHON_PACKAGE_PREFIX}.")
6+
37
################################################################################
48
# Structural groupings.
59
################################################################################
@@ -23,11 +27,6 @@ declare_mlir_python_sources(MLIRPythonSources.Core.Python
2327
passmanager.py
2428
rewrite.py
2529
dialects/_ods_common.py
26-
27-
# The main _mlir module has submodules: include stubs from each.
28-
_mlir_libs/_mlir/__init__.pyi
29-
_mlir_libs/_mlir/ir.pyi
30-
_mlir_libs/_mlir/passmanager.pyi
3130
)
3231

3332
declare_mlir_python_sources(MLIRPythonSources.Core.Python.Extras
@@ -43,7 +42,6 @@ declare_mlir_python_sources(MLIRPythonSources.ExecutionEngine
4342
ADD_TO_PARENT MLIRPythonSources
4443
SOURCES
4544
execution_engine.py
46-
_mlir_libs/_mlirExecutionEngine.pyi
4745
SOURCES_GLOB
4846
runtime/*.py
4947
)
@@ -195,7 +193,6 @@ declare_mlir_dialect_python_bindings(
195193
TD_FILE dialects/TransformOps.td
196194
SOURCES
197195
dialects/transform/__init__.py
198-
_mlir_libs/_mlir/dialects/transform/__init__.pyi
199196
DIALECT_NAME transform
200197
GEN_ENUM_BINDINGS_TD_FILE
201198
"../../include/mlir/Dialect/Transform/IR/TransformAttrs.td"
@@ -367,8 +364,7 @@ declare_mlir_python_sources(
367364
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
368365
GEN_ENUM_BINDINGS
369366
SOURCES
370-
dialects/quant.py
371-
_mlir_libs/_mlir/dialects/quant.pyi)
367+
dialects/quant.py)
372368

373369
declare_mlir_dialect_python_bindings(
374370
ADD_TO_PARENT MLIRPythonSources.Dialects
@@ -384,7 +380,6 @@ declare_mlir_dialect_python_bindings(
384380
TD_FILE dialects/PDLOps.td
385381
SOURCES
386382
dialects/pdl.py
387-
_mlir_libs/_mlir/dialects/pdl.pyi
388383
DIALECT_NAME pdl)
389384

390385
declare_mlir_dialect_python_bindings(
@@ -809,7 +804,7 @@ endif()
809804
add_mlir_python_common_capi_library(MLIRPythonCAPI
810805
INSTALL_COMPONENT MLIRPythonModules
811806
INSTALL_DESTINATION "${MLIR_BINDINGS_PYTHON_INSTALL_PREFIX}/_mlir_libs"
812-
OUTPUT_DIRECTORY "${MLIR_BINARY_DIR}/python_packages/mlir_core/mlir/_mlir_libs"
807+
OUTPUT_DIRECTORY "${MLIR_BINARY_DIR}/${MLIR_BINDINGS_PYTHON_INSTALL_PREFIX}/_mlir_libs"
813808
RELATIVE_INSTALL_ROOT "../../../.."
814809
DECLARED_HEADERS
815810
MLIRPythonCAPI.HeaderSources
@@ -838,7 +833,7 @@ endif()
838833
################################################################################
839834

840835
add_mlir_python_modules(MLIRPythonModules
841-
ROOT_PREFIX "${MLIR_BINARY_DIR}/python_packages/mlir_core/mlir"
836+
ROOT_PREFIX "${MLIR_BINARY_DIR}/${MLIR_BINDINGS_PYTHON_INSTALL_PREFIX}"
842837
INSTALL_PREFIX "${MLIR_BINDINGS_PYTHON_INSTALL_PREFIX}"
843838
DECLARED_SOURCES
844839
MLIRPythonSources
@@ -847,4 +842,3 @@ add_mlir_python_modules(MLIRPythonModules
847842
COMMON_CAPI_LINK_LIBS
848843
MLIRPythonCAPI
849844
)
850-
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
_mlir/**/*.pyi

mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi

Lines changed: 0 additions & 12 deletions
This file was deleted.

mlir/python/mlir/_mlir_libs/_mlir/dialects/pdl.pyi

Lines changed: 0 additions & 63 deletions
This file was deleted.

mlir/python/mlir/_mlir_libs/_mlir/dialects/quant.pyi

Lines changed: 0 additions & 142 deletions
This file was deleted.

0 commit comments

Comments
 (0)