Skip to content

Commit aae4b2d

Browse files
committed
Updated locks and added docs
1 parent c9bf7d2 commit aae4b2d

File tree

6 files changed

+129
-78
lines changed

6 files changed

+129
-78
lines changed

mlir/docs/Bindings/Python.md

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1187,3 +1187,43 @@ or nanobind and
11871187
utilities to connect to the rest of Python API. The bindings can be located in a
11881188
separate module or in the same module as attributes and types, and
11891189
loaded along with the dialect.
1190+
1191+
## Free-threading (No-GIL) support
1192+
1193+
Free-threading or no-GIL support refers to CPython interpreter (>=3.13) with Global Interpreter Lock made optional. For details on the topic, please check [PEP-703](https://peps.python.org/pep-0703/) and the this [link](https://py-free-threading.github.io/).
1194+
1195+
MLIR Python PyBind11 bindings are made free-threading compatible with exceptions (discussed below) in the following sense: it is safe to work in multiple threads with **independent** contexts/modules. Below we show an example code of safe usage:
1196+
1197+
```python
1198+
# python3.13t example.py
1199+
import concurrent.futures
1200+
1201+
import mlir.dialects.arith as arith
1202+
from mlir.ir import Context, Location, Module, IntegerType, InsertionPoint
1203+
1204+
1205+
def func(py_value):
1206+
with Context() as ctx:
1207+
module = Module.create(loc=Location.file("foo.txt", 0, 0))
1208+
1209+
dtype = IntegerType.get_signless(64)
1210+
with InsertionPoint(module.body), Location.name("a"):
1211+
arith.constant(dtype, py_value)
1212+
1213+
return module
1214+
1215+
1216+
num_workers = 8
1217+
with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
1218+
futures = []
1219+
for i in range(num_workers):
1220+
futures.append(executor.submit(func, i))
1221+
assert len(list(f.result() for f in futures)) == num_workers
1222+
```
1223+
1224+
The exceptions to the free-threading compatibility:
1225+
- registration methods and decorators, e.g. `register_dialect`, `register_operation`, `register_dialect`, `register_attribute_builder`, ...
1226+
- `ctypes` is unsafe
1227+
- IR printing is unsafe
1228+
1229+
For details, please see the list of xfailed tests in `mlir/test/python/multithreaded_tests.py`.

mlir/lib/Bindings/Python/Globals.h

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -36,20 +36,6 @@ class PyGlobals {
3636
return *instance;
3737
}
3838

39-
template<typename F>
40-
static inline auto withInstance(const F& cb) -> decltype(cb(get())) {
41-
auto &instance = get();
42-
#ifdef Py_GIL_DISABLED
43-
auto &lock = getLock();
44-
PyMutex_Lock(&lock);
45-
#endif
46-
auto result = cb(instance);
47-
#ifdef Py_GIL_DISABLED
48-
PyMutex_Unlock(&lock);
49-
#endif
50-
return result;
51-
}
52-
5339
/// Get and set the list of parent modules to search for dialect
5440
/// implementation classes.
5541
std::vector<std::string> &getDialectSearchPrefixes() {

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 11 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -198,9 +198,7 @@ py::object classmethod(Func f, Args... args) {
198198
static py::object
199199
createCustomDialectWrapper(const std::string &dialectNamespace,
200200
py::object dialectDescriptor) {
201-
auto dialectClass = PyGlobals::withInstance([&](PyGlobals& instance) {
202-
return instance.lookupDialectClass(dialectNamespace);
203-
});
201+
auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
204202
if (!dialectClass) {
205203
// Use the base class.
206204
return py::cast(PyDialect(std::move(dialectDescriptor)));
@@ -309,20 +307,15 @@ struct PyAttrBuilderMap {
309307
return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value();
310308
}
311309
static py::function dunderGetItemNamed(const std::string &attributeKind) {
312-
auto builder = PyGlobals::withInstance([&](PyGlobals& instance) {
313-
return instance.lookupAttributeBuilder(attributeKind);
314-
});
310+
auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind);
315311
if (!builder)
316312
throw py::key_error(attributeKind);
317313
return *builder;
318314
}
319315
static void dunderSetItemNamed(const std::string &attributeKind,
320316
py::function func, bool replace) {
321-
PyGlobals::withInstance([&](PyGlobals& instance) {
322-
instance.registerAttributeBuilder(attributeKind, std::move(func),
323-
replace);
324-
return 0;
325-
});
317+
PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func),
318+
replace);
326319
}
327320

328321
static void bind(py::module &m) {
@@ -1613,10 +1606,8 @@ py::object PyOperation::createOpView() {
16131606
checkValid();
16141607
MlirIdentifier ident = mlirOperationGetName(get());
16151608
MlirStringRef identStr = mlirIdentifierStr(ident);
1616-
auto operationCls = PyGlobals::withInstance([&](PyGlobals& instance){
1617-
return instance.lookupOperationClass(
1618-
StringRef(identStr.data, identStr.length));
1619-
});
1609+
auto operationCls = PyGlobals::get().lookupOperationClass(
1610+
StringRef(identStr.data, identStr.length));
16201611
if (operationCls)
16211612
return PyOpView::constructDerived(*operationCls, *getRef().get());
16221613
return py::cast(PyOpView(getRef().getObject()));
@@ -2067,9 +2058,7 @@ pybind11::object PyValue::maybeDownCast() {
20672058
assert(!mlirTypeIDIsNull(mlirTypeID) &&
20682059
"mlirTypeID was expected to be non-null.");
20692060
std::optional<pybind11::function> valueCaster =
2070-
PyGlobals::withInstance([&](PyGlobals& instance) {
2071-
return instance.lookupValueCaster(mlirTypeID, mlirTypeGetDialect(type));
2072-
});
2061+
PyGlobals::get().lookupValueCaster(mlirTypeID, mlirTypeGetDialect(type));
20732062
// py::return_value_policy::move means use std::move to move the return value
20742063
// contents into a new instance that will be owned by Python.
20752064
py::object thisObj = py::cast(this, py::return_value_policy::move);
@@ -3548,10 +3537,8 @@ void mlir::python::populateIRCore(py::module &m) {
35483537
assert(!mlirTypeIDIsNull(mlirTypeID) &&
35493538
"mlirTypeID was expected to be non-null.");
35503539
std::optional<pybind11::function> typeCaster =
3551-
PyGlobals::withInstance([&](PyGlobals& instance){
3552-
return instance.lookupTypeCaster(mlirTypeID,
3553-
mlirAttributeGetDialect(self));
3554-
});
3540+
PyGlobals::get().lookupTypeCaster(mlirTypeID,
3541+
mlirAttributeGetDialect(self));
35553542
if (!typeCaster)
35563543
return py::cast(self);
35573544
return typeCaster.value()(self);
@@ -3649,10 +3636,8 @@ void mlir::python::populateIRCore(py::module &m) {
36493636
assert(!mlirTypeIDIsNull(mlirTypeID) &&
36503637
"mlirTypeID was expected to be non-null.");
36513638
std::optional<pybind11::function> typeCaster =
3652-
PyGlobals::withInstance([&](PyGlobals& instance){
3653-
return instance.lookupTypeCaster(mlirTypeID,
3654-
mlirTypeGetDialect(self));
3655-
});
3639+
PyGlobals::get().lookupTypeCaster(mlirTypeID,
3640+
mlirTypeGetDialect(self));
36563641
if (!typeCaster)
36573642
return py::cast(self);
36583643
return typeCaster.value()(self);

mlir/lib/Bindings/Python/IRModule.cpp

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,21 @@ PyGlobals::PyGlobals() {
3737
PyGlobals::~PyGlobals() { instance = nullptr; }
3838

3939
bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
40-
if (loadedDialectModules.contains(dialectNamespace))
40+
auto already_loaded = false;
41+
#ifdef Py_GIL_DISABLED
42+
{
43+
auto &lock = getLock();
44+
PyMutex_Lock(&lock);
45+
#endif
46+
47+
already_loaded = loadedDialectModules.contains(dialectNamespace);
48+
49+
#ifdef Py_GIL_DISABLED
50+
PyMutex_Unlock(&lock);
51+
}
52+
#endif
53+
54+
if (already_loaded)
4155
return true;
4256
// Since re-entrancy is possible, make a copy of the search prefixes.
4357
std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
@@ -59,9 +73,24 @@ bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
5973

6074
if (loaded.is_none())
6175
return false;
62-
// Note: Iterator cannot be shared from prior to loading, since re-entrancy
63-
// may have occurred, which may do anything.
64-
loadedDialectModules.insert(dialectNamespace);
76+
77+
// We should use a lock in free-threading as loadDialectModule can be implicitly called by
78+
// python functions executed by in multiple threads context (e.g. lookupValueCaster).
79+
#ifdef Py_GIL_DISABLED
80+
{
81+
auto &lock = getLock();
82+
PyMutex_Lock(&lock);
83+
#endif
84+
85+
// Note: Iterator cannot be shared from prior to loading, since re-entrancy
86+
// may have occurred, which may do anything.
87+
loadedDialectModules.insert(dialectNamespace);
88+
89+
#ifdef Py_GIL_DISABLED
90+
PyMutex_Unlock(&lock);
91+
}
92+
#endif
93+
6594
return true;
6695
}
6796

mlir/lib/Bindings/Python/MainModule.cpp

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -74,13 +74,7 @@ PYBIND11_MODULE(_mlir, m, py::mod_gil_not_used()) {
7474
std::string operationName =
7575
opClass.attr("OPERATION_NAME").cast<std::string>();
7676

77-
// Use PyGlobals::withInstance instead of PyGlobals::get()
78-
// to prevent data race in multi-threaded context
79-
// Error raised in ir/opeation.py testKnownOpView test
80-
PyGlobals::withInstance([&](PyGlobals& instance) {
81-
instance.registerOperationImpl(operationName, opClass, replace);
82-
return 0;
83-
});
77+
PyGlobals::get().registerOperationImpl(operationName, opClass, replace);
8478
// Dict-stuff the new opClass by name onto the dialect class.
8579
py::object opClassName = opClass.attr("__name__");
8680
dialectClass.attr(opClassName) = opClass;

0 commit comments

Comments
 (0)