Skip to content

Commit 65a1006

Browse files
committed
[skip-ci] Added lock on PyGlobals::get and PyMlirContext liveContexts
WIP on adding multithreaded_tests
1 parent df311a8 commit 65a1006

File tree

9 files changed

+298
-31
lines changed

9 files changed

+298
-31
lines changed

mlir/docs/Bindings/Python.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1035,7 +1035,7 @@ class ConstantOp(_ods_ir.OpView):
10351035
...
10361036
```
10371037

1038-
expects `value` to be a `TypedAttr` (e.g., `IntegerAttr` or `FloatAttr`).
1038+
expects `value` to be a `TypedAttr` (e.g., `IntegerAttr` or `FloatAttr`).
10391039
Thus, a natural extension is a builder that accepts a MLIR type and a Python value and instantiates the appropriate `TypedAttr`:
10401040

10411041
```python
@@ -1181,9 +1181,9 @@ make the passes available along with the dialect.
11811181
Dialect functionality other than IR objects or passes, such as helper functions,
11821182
can be exposed to Python similarly to attributes and types. C API is expected to
11831183
exist for this functionality, which can then be wrapped using pybind11 and
1184-
`[include/mlir/Bindings/Python/PybindAdaptors.h](https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Bindings/Python/PybindAdaptors.h)`,
1184+
[`include/mlir/Bindings/Python/PybindAdaptors.h`](https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Bindings/Python/PybindAdaptors.h),
11851185
or nanobind and
1186-
`[include/mlir/Bindings/Python/NanobindAdaptors.h](https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h)`
1186+
[`include/mlir/Bindings/Python/NanobindAdaptors.h`](https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h)
11871187
utilities to connect to the rest of Python API. The bindings can be located in a
11881188
separate module or in the same module as attributes and types, and
11891189
loaded along with the dialect.

mlir/examples/standalone/python/StandaloneExtensionPybind11.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,11 @@
1212
#include "Standalone-c/Dialects.h"
1313
#include "mlir/Bindings/Python/PybindAdaptors.h"
1414

15+
namespace py = pybind11;
16+
1517
using namespace mlir::python::adaptors;
1618

17-
PYBIND11_MODULE(_standaloneDialectsPybind11, m) {
19+
PYBIND11_MODULE(_standaloneDialects, m, py::mod_gil_not_used()) {
1820
//===--------------------------------------------------------------------===//
1921
// standalone dialect
2022
//===--------------------------------------------------------------------===//

mlir/lib/Bindings/Python/Globals.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,20 @@ class PyGlobals {
3636
return *instance;
3737
}
3838

39+
template<typename F>
40+
static inline auto withInstance(const F& cb) -> decltype(cb(get())) {
41+
auto &instance = get();
42+
#ifdef Py_GIL_DISABLED
43+
auto &lock = getLock();
44+
PyMutex_Lock(&lock);
45+
#endif
46+
auto result = cb(instance);
47+
#ifdef Py_GIL_DISABLED
48+
PyMutex_Unlock(&lock);
49+
#endif
50+
return result;
51+
}
52+
3953
/// Get and set the list of parent modules to search for dialect
4054
/// implementation classes.
4155
std::vector<std::string> &getDialectSearchPrefixes() {
@@ -125,6 +139,14 @@ class PyGlobals {
125139
/// Set of dialect namespaces that we have attempted to import implementation
126140
/// modules for.
127141
llvm::StringSet<> loadedDialectModules;
142+
143+
#ifdef Py_GIL_DISABLED
144+
static PyMutex &getLock() {
145+
static PyMutex lock;
146+
return lock;
147+
}
148+
#endif
149+
128150
};
129151

130152
} // namespace python

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 47 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,9 @@ py::object classmethod(Func f, Args... args) {
198198
static py::object
199199
createCustomDialectWrapper(const std::string &dialectNamespace,
200200
py::object dialectDescriptor) {
201-
auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
201+
auto dialectClass = PyGlobals::withInstance([&](PyGlobals& instance) {
202+
return instance.lookupDialectClass(dialectNamespace);
203+
});
202204
if (!dialectClass) {
203205
// Use the base class.
204206
return py::cast(PyDialect(std::move(dialectDescriptor)));
@@ -601,16 +603,23 @@ class PyOpOperandIterator {
601603

602604
PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
603605
py::gil_scoped_acquire acquire;
604-
auto &liveContexts = getLiveContexts();
605-
liveContexts[context.ptr] = this;
606+
withLiveContexts([&](LiveContextMap& liveContexts) {
607+
liveContexts[context.ptr] = this;
608+
return this;
609+
});
606610
}
607611

608612
PyMlirContext::~PyMlirContext() {
609613
// Note that the only public way to construct an instance is via the
610614
// forContext method, which always puts the associated handle into
611615
// liveContexts.
612616
py::gil_scoped_acquire acquire;
613-
getLiveContexts().erase(context.ptr);
617+
618+
withLiveContexts([&](LiveContextMap& liveContexts) {
619+
liveContexts.erase(context.ptr);
620+
return this;
621+
});
622+
614623
mlirContextDestroy(context);
615624
}
616625

@@ -632,27 +641,32 @@ PyMlirContext *PyMlirContext::createNewContextForInit() {
632641

633642
PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
634643
py::gil_scoped_acquire acquire;
635-
auto &liveContexts = getLiveContexts();
636-
auto it = liveContexts.find(context.ptr);
637-
if (it == liveContexts.end()) {
638-
// Create.
639-
PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
640-
py::object pyRef = py::cast(unownedContextWrapper);
641-
assert(pyRef && "cast to py::object failed");
642-
liveContexts[context.ptr] = unownedContextWrapper;
643-
return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
644-
}
645-
// Use existing.
646-
py::object pyRef = py::cast(it->second);
647-
return PyMlirContextRef(it->second, std::move(pyRef));
644+
return withLiveContexts([&](LiveContextMap& liveContexts) {
645+
auto it = liveContexts.find(context.ptr);
646+
if (it == liveContexts.end()) {
647+
// Create.
648+
PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
649+
py::object pyRef = py::cast(unownedContextWrapper);
650+
assert(pyRef && "cast to py::object failed");
651+
liveContexts[context.ptr] = unownedContextWrapper;
652+
return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
653+
}
654+
// Use existing.
655+
py::object pyRef = py::cast(it->second);
656+
return PyMlirContextRef(it->second, std::move(pyRef));
657+
});
648658
}
649659

650660
PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
651661
static LiveContextMap liveContexts;
652662
return liveContexts;
653663
}
654664

655-
size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
665+
size_t PyMlirContext::getLiveCount() {
666+
return withLiveContexts([&](LiveContextMap& liveContexts) {
667+
return liveContexts.size();
668+
});
669+
}
656670

657671
size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
658672

@@ -1556,8 +1570,10 @@ py::object PyOperation::createOpView() {
15561570
checkValid();
15571571
MlirIdentifier ident = mlirOperationGetName(get());
15581572
MlirStringRef identStr = mlirIdentifierStr(ident);
1559-
auto operationCls = PyGlobals::get().lookupOperationClass(
1560-
StringRef(identStr.data, identStr.length));
1573+
auto operationCls = PyGlobals::withInstance([&](PyGlobals& instance){
1574+
return instance.lookupOperationClass(
1575+
StringRef(identStr.data, identStr.length));
1576+
});
15611577
if (operationCls)
15621578
return PyOpView::constructDerived(*operationCls, *getRef().get());
15631579
return py::cast(PyOpView(getRef().getObject()));
@@ -2008,7 +2024,9 @@ pybind11::object PyValue::maybeDownCast() {
20082024
assert(!mlirTypeIDIsNull(mlirTypeID) &&
20092025
"mlirTypeID was expected to be non-null.");
20102026
std::optional<pybind11::function> valueCaster =
2011-
PyGlobals::get().lookupValueCaster(mlirTypeID, mlirTypeGetDialect(type));
2027+
PyGlobals::withInstance([&](PyGlobals& instance) {
2028+
return instance.lookupValueCaster(mlirTypeID, mlirTypeGetDialect(type));
2029+
});
20122030
// py::return_value_policy::move means use std::move to move the return value
20132031
// contents into a new instance that will be owned by Python.
20142032
py::object thisObj = py::cast(this, py::return_value_policy::move);
@@ -3487,8 +3505,10 @@ void mlir::python::populateIRCore(py::module &m) {
34873505
assert(!mlirTypeIDIsNull(mlirTypeID) &&
34883506
"mlirTypeID was expected to be non-null.");
34893507
std::optional<pybind11::function> typeCaster =
3490-
PyGlobals::get().lookupTypeCaster(mlirTypeID,
3491-
mlirAttributeGetDialect(self));
3508+
PyGlobals::withInstance([&](PyGlobals& instance){
3509+
return instance.lookupTypeCaster(mlirTypeID,
3510+
mlirAttributeGetDialect(self));
3511+
});
34923512
if (!typeCaster)
34933513
return py::cast(self);
34943514
return typeCaster.value()(self);
@@ -3585,9 +3605,11 @@ void mlir::python::populateIRCore(py::module &m) {
35853605
MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
35863606
assert(!mlirTypeIDIsNull(mlirTypeID) &&
35873607
"mlirTypeID was expected to be non-null.");
3588-
std::optional<pybind11::function> typeCaster =
3589-
PyGlobals::get().lookupTypeCaster(mlirTypeID,
3608+
std::optional<pybind11::function> typeCaster =
3609+
PyGlobals::withInstance([&](PyGlobals& instance){
3610+
return instance.lookupTypeCaster(mlirTypeID,
35903611
mlirTypeGetDialect(self));
3612+
});
35913613
if (!typeCaster)
35923614
return py::cast(self);
35933615
return typeCaster.value()(self);

mlir/lib/Bindings/Python/IRModule.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,27 @@ class PyMlirContext {
263263
using LiveContextMap = llvm::DenseMap<void *, PyMlirContext *>;
264264
static LiveContextMap &getLiveContexts();
265265

266+
#ifdef Py_GIL_DISABLED
267+
static PyMutex &getLock() {
268+
static PyMutex lock;
269+
return lock;
270+
}
271+
#endif
272+
273+
template<typename F>
274+
static inline auto withLiveContexts(const F& cb) -> decltype(cb(getLiveContexts())) {
275+
auto &liveContexts = getLiveContexts();
276+
#ifdef Py_GIL_DISABLED
277+
auto &lock = getLock();
278+
PyMutex_Lock(&lock);
279+
#endif
280+
auto result = cb(liveContexts);
281+
#ifdef Py_GIL_DISABLED
282+
PyMutex_Unlock(&lock);
283+
#endif
284+
return result;
285+
}
286+
266287
// Interns all live modules associated with this context. Modules tracked
267288
// in this map are valid. When a module is invalidated, it is removed
268289
// from this map, and while it still exists as an instance, any

mlir/test/python/execution_engine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ def callback(a):
306306
log(arr)
307307

308308
with Context():
309-
# The module takes a subview of the argument memref, casts it to an unranked memref and
309+
# The module takes a subview of the argument memref, casts it to an unranked memref and
310310
# calls the callback with it.
311311
module = Module.parse(
312312
r"""

mlir/test/python/lib/PythonTestModulePybind11.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ static bool mlirTypeIsARankedIntegerTensor(MlirType t) {
2323
mlirTypeIsAInteger(mlirShapedTypeGetElementType(t));
2424
}
2525

26-
PYBIND11_MODULE(_mlirPythonTestPybind11, m) {
26+
PYBIND11_MODULE(_mlirPythonTest, m, py::mod_gil_not_used()) {
2727
m.def(
2828
"register_python_test_dialect",
2929
[](MlirContext context, bool load) {

0 commit comments

Comments
 (0)