Skip to content

Commit 6e530d2

Browse files
committed
[MLIR][Python] enable precise registration
1 parent 5815846 commit 6e530d2

File tree

13 files changed

+206
-33
lines changed

13 files changed

+206
-33
lines changed

mlir/include/mlir-c/Bindings/Python/Interop.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@
8484
#define MLIR_PYTHON_CAPSULE_VALUE MAKE_MLIR_PYTHON_QUALNAME("ir.Value._CAPIPtr")
8585
#define MLIR_PYTHON_CAPSULE_TYPEID \
8686
MAKE_MLIR_PYTHON_QUALNAME("ir.TypeID._CAPIPtr")
87+
#define MLIR_PYTHON_CAPSULE_DIALECT_HANDLE \
88+
MAKE_MLIR_PYTHON_QUALNAME("ir.DialectHandle._CAPIPtr")
8789

8890
/** Attribute on MLIR Python objects that expose their C-API pointer.
8991
* This will be a type-specific capsule created as per one of the helpers
@@ -457,6 +459,13 @@ static inline MlirValue mlirPythonCapsuleToValue(PyObject *capsule) {
457459
return value;
458460
}
459461

462+
static inline MlirDialectHandle
463+
mlirPythonCapsuleToDialectHandle(PyObject *capsule) {
464+
void *ptr = PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_DIALECT_HANDLE);
465+
MlirDialectHandle handle = {ptr};
466+
return handle;
467+
}
468+
460469
#ifdef __cplusplus
461470
}
462471
#endif
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
//===-- mlir-c/Dialect/Builtin.h - C API for Builtin dialect ------*- C -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
4+
// Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===----------------------------------------------------------------------===//
9+
//
10+
// This header declares the C interface for registering and accessing the
11+
// Builtin dialect. A dialect should be registered with a context to make it
12+
// available to users of the context. These users must load the dialect
13+
// before using any of its attributes, operations or types. Parser and pass
14+
// manager can load registered dialects automatically.
15+
//
16+
//===----------------------------------------------------------------------===//
17+
18+
#ifndef MLIR_C_DIALECT_BUILTIN_H
19+
#define MLIR_C_DIALECT_BUILTIN_H
20+
21+
#include "mlir-c/IR.h"
22+
23+
#ifdef __cplusplus
24+
extern "C" {
25+
#endif
26+
27+
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Builtin, builtin);
28+
29+
#ifdef __cplusplus
30+
}
31+
#endif
32+
33+
#endif // MLIR_C_DIALECT_BUILTIN_H

mlir/include/mlir-c/IR.h

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ DEFINE_C_API_STRUCT(MlirLocation, const void);
6666
DEFINE_C_API_STRUCT(MlirModule, const void);
6767
DEFINE_C_API_STRUCT(MlirType, const void);
6868
DEFINE_C_API_STRUCT(MlirValue, const void);
69+
DEFINE_C_API_STRUCT(MlirDialectHandle, const void);
6970

7071
#undef DEFINE_C_API_STRUCT
7172

@@ -207,11 +208,6 @@ MLIR_CAPI_EXPORTED MlirStringRef mlirDialectGetNamespace(MlirDialect dialect);
207208
// registration schemes.
208209
//===----------------------------------------------------------------------===//
209210

210-
struct MlirDialectHandle {
211-
const void *ptr;
212-
};
213-
typedef struct MlirDialectHandle MlirDialectHandle;
214-
215211
#define MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Name, Namespace) \
216212
MLIR_CAPI_EXPORTED MlirDialectHandle mlirGetDialectHandle__##Namespace##__( \
217213
void)
@@ -233,6 +229,11 @@ MLIR_CAPI_EXPORTED void mlirDialectHandleRegisterDialect(MlirDialectHandle,
233229
MLIR_CAPI_EXPORTED MlirDialect mlirDialectHandleLoadDialect(MlirDialectHandle,
234230
MlirContext);
235231

232+
/// Checks if the dialect handle is null.
233+
static inline bool mlirDialectHandleIsNull(MlirDialectHandle handle) {
234+
return !handle.ptr;
235+
}
236+
236237
//===----------------------------------------------------------------------===//
237238
// DialectRegistry API.
238239
//===----------------------------------------------------------------------===//
@@ -249,6 +250,13 @@ static inline bool mlirDialectRegistryIsNull(MlirDialectRegistry registry) {
249250
MLIR_CAPI_EXPORTED void
250251
mlirDialectRegistryDestroy(MlirDialectRegistry registry);
251252

253+
MLIR_CAPI_EXPORTED int64_t
254+
mlirDialectRegistryGetNumDialectNames(MlirDialectRegistry registry);
255+
256+
MLIR_CAPI_EXPORTED void
257+
mlirDialectRegistryGetDialectNames(MlirDialectRegistry registry,
258+
MlirStringRef *dialectNames);
259+
252260
//===----------------------------------------------------------------------===//
253261
// Location API.
254262
//===----------------------------------------------------------------------===//

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2897,6 +2897,14 @@ maybeGetTracebackLocation(const std::optional<PyLocation> &location) {
28972897
// Populates the core exports of the 'ir' submodule.
28982898
//------------------------------------------------------------------------------
28992899

2900+
MlirDialectHandle createMlirDialectHandleFromCapsule(nb::object capsule) {
2901+
MlirDialectHandle rawRegistry =
2902+
mlirPythonCapsuleToDialectHandle(capsule.ptr());
2903+
if (mlirDialectHandleIsNull(rawRegistry))
2904+
throw nb::python_error();
2905+
return rawRegistry;
2906+
}
2907+
29002908
void mlir::python::populateIRCore(nb::module_ &m) {
29012909
// disable leak warnings which tend to be false positives.
29022910
nb::set_leak_warnings(false);
@@ -3126,14 +3134,38 @@ void mlir::python::populateIRCore(nb::module_ &m) {
31263134
},
31273135
nb::sig("def __repr__(self) -> str"));
31283136

3137+
//----------------------------------------------------------------------------
3138+
// Mapping of MlirDialectHandle
3139+
//----------------------------------------------------------------------------
3140+
3141+
nb::class_<MlirDialectHandle>(m, "DialectHandle")
3142+
.def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR,
3143+
&createMlirDialectHandleFromCapsule);
3144+
31293145
//----------------------------------------------------------------------------
31303146
// Mapping of PyDialectRegistry
31313147
//----------------------------------------------------------------------------
31323148
nb::class_<PyDialectRegistry>(m, "DialectRegistry")
31333149
.def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyDialectRegistry::getCapsule)
31343150
.def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR,
31353151
&PyDialectRegistry::createFromCapsule)
3136-
.def(nb::init<>());
3152+
.def(nb::init<>())
3153+
.def("insert_dialect",
3154+
[](PyDialectRegistry &self, MlirDialectHandle handle) {
3155+
mlirDialectHandleInsertDialect(handle, self.get());
3156+
})
3157+
.def("insert_dialect",
3158+
[](PyDialectRegistry &self, intptr_t ptr) {
3159+
mlirDialectHandleInsertDialect(
3160+
{reinterpret_cast<const void *>(ptr)}, self.get());
3161+
})
3162+
.def_prop_ro("dialect_names", [](PyDialectRegistry &self) {
3163+
int64_t numDialectNames =
3164+
mlirDialectRegistryGetNumDialectNames(self.get());
3165+
std::vector<MlirStringRef> dialectNames(numDialectNames);
3166+
mlirDialectRegistryGetDialectNames(self.get(), dialectNames.data());
3167+
return dialectNames;
3168+
});
31373169

31383170
//----------------------------------------------------------------------------
31393171
// Mapping of Location

mlir/lib/CAPI/Dialect/Builtin.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
//===- Builtin.cpp - C Interface for Builtin dialect ----------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir-c/Dialect/Builtin.h"
10+
#include "mlir/CAPI/Registration.h"
11+
#include "mlir/IR/BuiltinDialect.h"
12+
13+
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Builtin, builtin, mlir::BuiltinDialect)

mlir/lib/CAPI/Dialect/CMakeLists.txt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,14 @@ add_mlir_upstream_c_api_library(MLIRCAPIArith
1616
MLIRArithDialect
1717
)
1818

19+
add_mlir_upstream_c_api_library(MLIRCAPIBuiltin
20+
Builtin.cpp
21+
22+
PARTIAL_SOURCES_INTENDED
23+
LINK_LIBS PUBLIC
24+
MLIRCAPIIR
25+
)
26+
1927
add_mlir_upstream_c_api_library(MLIRCAPIAsync
2028
Async.cpp
2129
AsyncPasses.cpp

mlir/lib/CAPI/IR/IR.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,18 @@ void mlirDialectRegistryDestroy(MlirDialectRegistry registry) {
150150
delete unwrap(registry);
151151
}
152152

153+
int64_t mlirDialectRegistryGetNumDialectNames(MlirDialectRegistry registry) {
154+
auto dialectNames = unwrap(registry)->getDialectNames();
155+
return std::distance(dialectNames.begin(), dialectNames.end());
156+
}
157+
158+
void mlirDialectRegistryGetDialectNames(MlirDialectRegistry registry,
159+
MlirStringRef *dialectNames) {
160+
for (auto [i, location] :
161+
llvm::enumerate(unwrap(registry)->getDialectNames()))
162+
dialectNames[i] = wrap(location);
163+
}
164+
153165
//===----------------------------------------------------------------------===//
154166
// AsmState API.
155167
//===----------------------------------------------------------------------===//

mlir/python/CMakeLists.txt

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -528,31 +528,32 @@ declare_mlir_python_extension(MLIRPythonExtension.Core
528528
MLIRCAPIDebug
529529
MLIRCAPIIR
530530
MLIRCAPIInterfaces
531+
MLIRCAPITransforms
532+
MLIRCAPIBuiltin
531533

532534
# Dialects
533535
MLIRCAPIFunc
534536
)
535537

536-
# This extension exposes an API to register all dialects, extensions, and passes
537-
# packaged in upstream MLIR and it is used for the upstream "mlir" Python
538-
# package. Downstreams will likely want to provide their own and not depend
539-
# on this one, since it links in the world.
540-
# Note that this is not added to any top-level source target for transitive
541-
# inclusion: It must be included explicitly by downstreams if desired. Note that
542-
# this has a very large impact on what gets built/packaged.
543-
declare_mlir_python_extension(MLIRPythonExtension.RegisterEverything
544-
MODULE_NAME _mlirRegisterEverything
545-
ROOT_DIR "${PYTHON_SOURCE_DIR}"
546-
PYTHON_BINDINGS_LIBRARY nanobind
547-
SOURCES
548-
RegisterEverything.cpp
549-
PRIVATE_LINK_LIBS
550-
LLVMSupport
551-
EMBED_CAPI_LINK_LIBS
552-
MLIRCAPIConversion
553-
MLIRCAPITransforms
554-
MLIRCAPIRegisterEverything
555-
)
538+
## This extension exposes an API to register all dialects, extensions, and passes
539+
## packaged in upstream MLIR and it is used for the upstream "mlir" Python
540+
## package. Downstreams will likely want to provide their own and not depend
541+
## on this one, since it links in the world.
542+
## Note that this is not added to any top-level source target for transitive
543+
## inclusion: It must be included explicitly by downstreams if desired. Note that
544+
## this has a very large impact on what gets built/packaged.
545+
#declare_mlir_python_extension(MLIRPythonExtension.RegisterEverything
546+
# MODULE_NAME _mlirRegisterEverything
547+
# ROOT_DIR "${PYTHON_SOURCE_DIR}"
548+
# PYTHON_BINDINGS_LIBRARY nanobind
549+
# SOURCES
550+
# RegisterEverything.cpp
551+
# PRIVATE_LINK_LIBS
552+
# LLVMSupport
553+
# MLIRCAPIConversion
554+
# MLIRCAPITransforms
555+
# MLIRCAPIRegisterEverything
556+
#)
556557

557558
declare_mlir_python_extension(MLIRPythonExtension.Dialects.Linalg.Pybind
558559
MODULE_NAME _mlirDialectsLinalg
@@ -871,7 +872,6 @@ add_mlir_python_common_capi_library(${MLIR_PYTHON_CAPI_DYLIB_NAME}
871872
MLIRPythonCAPI.HeaderSources
872873
DECLARED_SOURCES
873874
MLIRPythonSources
874-
MLIRPythonExtension.RegisterEverything
875875
${_ADDL_TEST_SOURCES}
876876
)
877877

@@ -979,7 +979,6 @@ add_mlir_python_modules(MLIRPythonModules
979979
INSTALL_PREFIX "${MLIR_BINDINGS_PYTHON_INSTALL_PREFIX}"
980980
DECLARED_SOURCES
981981
MLIRPythonSources
982-
MLIRPythonExtension.RegisterEverything
983982
MLIRPythonExtension.Core.type_stub_gen
984983
MLIRPythonCAPICTypesBinding
985984
${_ADDL_TEST_SOURCES}

mlir/python/mlir/_mlir_libs/_capi.py.in

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,12 @@
55
import ctypes
66
from pathlib import Path
77

8-
_capi = ctypes.CDLL(str(Path(__file__).parent / "@CMAKE_SHARED_LIBRARY_PREFIX@@MLIR_PYTHON_CAPI_DYLIB_NAME@@CMAKE_SHARED_LIBRARY_SUFFIX@"))
8+
_capi = ctypes.CDLL(str(Path(__file__).parent / "@CMAKE_SHARED_LIBRARY_PREFIX@@MLIR_PYTHON_CAPI_DYLIB_NAME@@CMAKE_SHARED_LIBRARY_SUFFIX@"))
9+
10+
PyCapsule_New = ctypes.pythonapi.PyCapsule_New
11+
PyCapsule_New.restype = ctypes.py_object
12+
PyCapsule_New.argtypes = ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p
13+
14+
MLIR_PYTHON_CAPSULE_DIALECT_HANDLE = (
15+
"@[email protected]._CAPIPtr"
16+
).encode()

mlir/test/python/dialects/irdl.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from mlir.ir import *
44
from mlir.dialects.irdl import *
5+
import mlir.dialects.arith
56
import sys
67

78

0 commit comments

Comments
 (0)