Skip to content

Commit b1e00f6

Browse files
authored
[mlir][python] Cache import of ir module in type casters. (llvm#160000)
In a JAX benchmark that traces a large language model, this change reduces the time spent in nanobind::module::import_ from 1.2s to 10ms.
1 parent 4a7179f commit b1e00f6

File tree

1 file changed

+71
-29
lines changed

1 file changed

+71
-29
lines changed

mlir/include/mlir/Bindings/Python/NanobindAdaptors.h

Lines changed: 71 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919
#ifndef MLIR_BINDINGS_PYTHON_NANOBINDADAPTORS_H
2020
#define MLIR_BINDINGS_PYTHON_NANOBINDADAPTORS_H
2121

22+
#include <atomic>
2223
#include <cstdint>
24+
#include <memory>
2325
#include <optional>
2426

2527
#include "mlir-c/Diagnostics.h"
@@ -30,6 +32,57 @@
3032
// clang-format on
3133
#include "llvm/ADT/Twine.h"
3234

35+
namespace mlir {
36+
namespace python {
37+
namespace {
38+
39+
// Safely calls Python initialization code on first use, avoiding deadlocks.
40+
template <typename T>
41+
class SafeInit {
42+
public:
43+
typedef std::unique_ptr<T> (*F)();
44+
45+
explicit SafeInit(F init_fn) : initFn(init_fn) {}
46+
47+
T &get() {
48+
if (T *result = output.load()) {
49+
return *result;
50+
}
51+
52+
// Note: init_fn() may be called multiple times if, for example, the GIL is
53+
// released during its execution. The intended use case is for module
54+
// imports which are safe to perform multiple times. We are careful not to
55+
// hold a lock across init_fn() to avoid lock ordering problems.
56+
std::unique_ptr<T> m = initFn();
57+
{
58+
nanobind::ft_lock_guard lock(mu);
59+
if (T *result = output.load()) {
60+
return *result;
61+
}
62+
T *p = m.release();
63+
output.store(p);
64+
return *p;
65+
}
66+
}
67+
68+
private:
69+
nanobind::ft_mutex mu;
70+
std::atomic<T *> output{nullptr};
71+
F initFn;
72+
};
73+
74+
nanobind::module_ &irModule() {
75+
static SafeInit<nanobind::module_> init([]() {
76+
return std::make_unique<nanobind::module_>(
77+
nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")));
78+
});
79+
return init.get();
80+
}
81+
82+
} // namespace
83+
} // namespace python
84+
} // namespace mlir
85+
3386
// Raw CAPI type casters need to be declared before use, so always include them
3487
// first.
3588
namespace nanobind {
@@ -75,7 +128,7 @@ struct type_caster<MlirAffineMap> {
75128
cleanup_list *cleanup) noexcept {
76129
nanobind::object capsule =
77130
nanobind::steal<nanobind::object>(mlirPythonAffineMapToCapsule(v));
78-
return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
131+
return mlir::python::irModule()
79132
.attr("AffineMap")
80133
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
81134
.release();
@@ -97,7 +150,7 @@ struct type_caster<MlirAttribute> {
97150
cleanup_list *cleanup) noexcept {
98151
nanobind::object capsule =
99152
nanobind::steal<nanobind::object>(mlirPythonAttributeToCapsule(v));
100-
return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
153+
return mlir::python::irModule()
101154
.attr("Attribute")
102155
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
103156
.attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)()
@@ -128,9 +181,7 @@ struct type_caster<MlirContext> {
128181
// TODO: This raises an error of "No current context" currently.
129182
// Update the implementation to pretty-print the helpful error that the
130183
// core implementations print in this case.
131-
src = nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
132-
.attr("Context")
133-
.attr("current");
184+
src = mlir::python::irModule().attr("Context").attr("current");
134185
}
135186
std::optional<nanobind::object> capsule = mlirApiObjectToCapsule(src);
136187
value = mlirPythonCapsuleToContext(capsule->ptr());
@@ -153,7 +204,7 @@ struct type_caster<MlirDialectRegistry> {
153204
cleanup_list *cleanup) noexcept {
154205
nanobind::object capsule = nanobind::steal<nanobind::object>(
155206
mlirPythonDialectRegistryToCapsule(v));
156-
return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
207+
return mlir::python::irModule()
157208
.attr("DialectRegistry")
158209
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
159210
.release();
@@ -167,9 +218,7 @@ struct type_caster<MlirLocation> {
167218
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
168219
if (src.is_none()) {
169220
// Gets the current thread-bound context.
170-
src = nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
171-
.attr("Location")
172-
.attr("current");
221+
src = mlir::python::irModule().attr("Location").attr("current");
173222
}
174223
if (auto capsule = mlirApiObjectToCapsule(src)) {
175224
value = mlirPythonCapsuleToLocation(capsule->ptr());
@@ -181,7 +230,7 @@ struct type_caster<MlirLocation> {
181230
cleanup_list *cleanup) noexcept {
182231
nanobind::object capsule =
183232
nanobind::steal<nanobind::object>(mlirPythonLocationToCapsule(v));
184-
return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
233+
return mlir::python::irModule()
185234
.attr("Location")
186235
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
187236
.release();
@@ -203,7 +252,7 @@ struct type_caster<MlirModule> {
203252
cleanup_list *cleanup) noexcept {
204253
nanobind::object capsule =
205254
nanobind::steal<nanobind::object>(mlirPythonModuleToCapsule(v));
206-
return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
255+
return mlir::python::irModule()
207256
.attr("Module")
208257
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
209258
.release();
@@ -250,7 +299,7 @@ struct type_caster<MlirOperation> {
250299
return nanobind::none();
251300
nanobind::object capsule =
252301
nanobind::steal<nanobind::object>(mlirPythonOperationToCapsule(v));
253-
return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
302+
return mlir::python::irModule()
254303
.attr("Operation")
255304
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
256305
.release();
@@ -274,7 +323,7 @@ struct type_caster<MlirValue> {
274323
return nanobind::none();
275324
nanobind::object capsule =
276325
nanobind::steal<nanobind::object>(mlirPythonValueToCapsule(v));
277-
return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
326+
return mlir::python::irModule()
278327
.attr("Value")
279328
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
280329
.attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)()
@@ -312,7 +361,7 @@ struct type_caster<MlirTypeID> {
312361
return nanobind::none();
313362
nanobind::object capsule =
314363
nanobind::steal<nanobind::object>(mlirPythonTypeIDToCapsule(v));
315-
return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
364+
return mlir::python::irModule()
316365
.attr("TypeID")
317366
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
318367
.release();
@@ -334,7 +383,7 @@ struct type_caster<MlirType> {
334383
cleanup_list *cleanup) noexcept {
335384
nanobind::object capsule =
336385
nanobind::steal<nanobind::object>(mlirPythonTypeToCapsule(t));
337-
return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
386+
return mlir::python::irModule()
338387
.attr("Type")
339388
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
340389
.attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)()
@@ -453,11 +502,9 @@ class mlir_attribute_subclass : public pure_subclass {
453502
mlir_attribute_subclass(nanobind::handle scope, const char *attrClassName,
454503
IsAFunctionTy isaFunction,
455504
GetTypeIDFunctionTy getTypeIDFunction = nullptr)
456-
: mlir_attribute_subclass(
457-
scope, attrClassName, isaFunction,
458-
nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
459-
.attr("Attribute"),
460-
getTypeIDFunction) {}
505+
: mlir_attribute_subclass(scope, attrClassName, isaFunction,
506+
irModule().attr("Attribute"),
507+
getTypeIDFunction) {}
461508

462509
/// Subclasses with a provided mlir.ir.Attribute super-class. This must
463510
/// be used if the subclass is being defined in the same extension module
@@ -540,11 +587,8 @@ class mlir_type_subclass : public pure_subclass {
540587
mlir_type_subclass(nanobind::handle scope, const char *typeClassName,
541588
IsAFunctionTy isaFunction,
542589
GetTypeIDFunctionTy getTypeIDFunction = nullptr)
543-
: mlir_type_subclass(
544-
scope, typeClassName, isaFunction,
545-
nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
546-
.attr("Type"),
547-
getTypeIDFunction) {}
590+
: mlir_type_subclass(scope, typeClassName, isaFunction,
591+
irModule().attr("Type"), getTypeIDFunction) {}
548592

549593
/// Subclasses with a provided mlir.ir.Type super-class. This must
550594
/// be used if the subclass is being defined in the same extension module
@@ -631,10 +675,8 @@ class mlir_value_subclass : public pure_subclass {
631675
/// Subclasses by looking up the super-class dynamically.
632676
mlir_value_subclass(nanobind::handle scope, const char *valueClassName,
633677
IsAFunctionTy isaFunction)
634-
: mlir_value_subclass(
635-
scope, valueClassName, isaFunction,
636-
nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
637-
.attr("Value")) {}
678+
: mlir_value_subclass(scope, valueClassName, isaFunction,
679+
irModule().attr("Value")) {}
638680

639681
/// Subclasses with a provided mlir.ir.Value super-class. This must
640682
/// be used if the subclass is being defined in the same extension module

0 commit comments

Comments
 (0)