Skip to content

Commit 0683720

Browse files
committed
Updated locks and added docs
1 parent 303c87e commit 0683720

File tree

6 files changed

+109
-74
lines changed

6 files changed

+109
-74
lines changed

mlir/docs/Bindings/Python.md

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1187,3 +1187,43 @@ or nanobind and
11871187
utilities to connect to the rest of Python API. The bindings can be located in a
11881188
separate module or in the same module as attributes and types, and
11891189
loaded along with the dialect.
1190+
1191+
## Free-threading (No-GIL) support
1192+
1193+
Free-threading (or no-GIL) support refers to the CPython interpreter (>=3.13) with the Global Interpreter Lock made optional. For details on the topic, please see [PEP-703](https://peps.python.org/pep-0703/) and the [Python free-threading guide](https://py-free-threading.github.io/).
1194+
1195+
MLIR Python PyBind11 bindings are free-threading compatible, with some exceptions (discussed below), in the following sense: it is safe to work in multiple threads with **independent** contexts/modules. The example below shows safe usage:
1196+
1197+
```python
1198+
# python3.13t example.py
1199+
import concurrent.futures
1200+
1201+
import mlir.dialects.arith as arith
1202+
from mlir.ir import Context, Location, Module, IntegerType, InsertionPoint
1203+
1204+
1205+
def func(py_value):
1206+
with Context() as ctx:
1207+
module = Module.create(loc=Location.file("foo.txt", 0, 0))
1208+
1209+
dtype = IntegerType.get_signless(64)
1210+
with InsertionPoint(module.body), Location.name("a"):
1211+
arith.constant(dtype, py_value)
1212+
1213+
return module
1214+
1215+
1216+
num_workers = 8
1217+
with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
1218+
futures = []
1219+
for i in range(num_workers):
1220+
futures.append(executor.submit(func, i))
1221+
assert len(list(f.result() for f in futures)) == num_workers
1222+
```
1223+
1224+
The exceptions to the free-threading compatibility:
1225+
- registration methods and decorators, e.g. `register_dialect`, `register_operation`, `register_attribute_builder`, ...
1226+
- `ctypes` is unsafe
1227+
- IR printing is unsafe
1228+
1229+
For details, please see the list of xfailed tests in `mlir/test/python/multithreaded_tests.py`.

mlir/lib/Bindings/Python/Globals.h

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -36,20 +36,6 @@ class PyGlobals {
3636
return *instance;
3737
}
3838

39-
template<typename F>
40-
static inline auto withInstance(const F& cb) -> decltype(cb(get())) {
41-
auto &instance = get();
42-
#ifdef Py_GIL_DISABLED
43-
auto &lock = getLock();
44-
PyMutex_Lock(&lock);
45-
#endif
46-
auto result = cb(instance);
47-
#ifdef Py_GIL_DISABLED
48-
PyMutex_Unlock(&lock);
49-
#endif
50-
return result;
51-
}
52-
5339
/// Get and set the list of parent modules to search for dialect
5440
/// implementation classes.
5541
std::vector<std::string> &getDialectSearchPrefixes() {

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 11 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -198,9 +198,7 @@ py::object classmethod(Func f, Args... args) {
198198
static py::object
199199
createCustomDialectWrapper(const std::string &dialectNamespace,
200200
py::object dialectDescriptor) {
201-
auto dialectClass = PyGlobals::withInstance([&](PyGlobals& instance) {
202-
return instance.lookupDialectClass(dialectNamespace);
203-
});
201+
auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
204202
if (!dialectClass) {
205203
// Use the base class.
206204
return py::cast(PyDialect(std::move(dialectDescriptor)));
@@ -309,20 +307,15 @@ struct PyAttrBuilderMap {
309307
return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value();
310308
}
311309
static py::function dunderGetItemNamed(const std::string &attributeKind) {
312-
auto builder = PyGlobals::withInstance([&](PyGlobals& instance) {
313-
return instance.lookupAttributeBuilder(attributeKind);
314-
});
310+
auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind);
315311
if (!builder)
316312
throw py::key_error(attributeKind);
317313
return *builder;
318314
}
319315
static void dunderSetItemNamed(const std::string &attributeKind,
320316
py::function func, bool replace) {
321-
PyGlobals::withInstance([&](PyGlobals& instance) {
322-
instance.registerAttributeBuilder(attributeKind, std::move(func),
323-
replace);
324-
return 0;
325-
});
317+
PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func),
318+
replace);
326319
}
327320

328321
static void bind(py::module &m) {
@@ -1613,10 +1606,8 @@ py::object PyOperation::createOpView() {
16131606
checkValid();
16141607
MlirIdentifier ident = mlirOperationGetName(get());
16151608
MlirStringRef identStr = mlirIdentifierStr(ident);
1616-
auto operationCls = PyGlobals::withInstance([&](PyGlobals& instance){
1617-
return instance.lookupOperationClass(
1618-
StringRef(identStr.data, identStr.length));
1619-
});
1609+
auto operationCls = PyGlobals::get().lookupOperationClass(
1610+
StringRef(identStr.data, identStr.length));
16201611
if (operationCls)
16211612
return PyOpView::constructDerived(*operationCls, *getRef().get());
16221613
return py::cast(PyOpView(getRef().getObject()));
@@ -2067,9 +2058,7 @@ pybind11::object PyValue::maybeDownCast() {
20672058
assert(!mlirTypeIDIsNull(mlirTypeID) &&
20682059
"mlirTypeID was expected to be non-null.");
20692060
std::optional<pybind11::function> valueCaster =
2070-
PyGlobals::withInstance([&](PyGlobals& instance) {
2071-
return instance.lookupValueCaster(mlirTypeID, mlirTypeGetDialect(type));
2072-
});
2061+
PyGlobals::get().lookupValueCaster(mlirTypeID, mlirTypeGetDialect(type));
20732062
// py::return_value_policy::move means use std::move to move the return value
20742063
// contents into a new instance that will be owned by Python.
20752064
py::object thisObj = py::cast(this, py::return_value_policy::move);
@@ -3548,10 +3537,8 @@ void mlir::python::populateIRCore(py::module &m) {
35483537
assert(!mlirTypeIDIsNull(mlirTypeID) &&
35493538
"mlirTypeID was expected to be non-null.");
35503539
std::optional<pybind11::function> typeCaster =
3551-
PyGlobals::withInstance([&](PyGlobals& instance){
3552-
return instance.lookupTypeCaster(mlirTypeID,
3553-
mlirAttributeGetDialect(self));
3554-
});
3540+
PyGlobals::get().lookupTypeCaster(mlirTypeID,
3541+
mlirAttributeGetDialect(self));
35553542
if (!typeCaster)
35563543
return py::cast(self);
35573544
return typeCaster.value()(self);
@@ -3649,10 +3636,8 @@ void mlir::python::populateIRCore(py::module &m) {
36493636
assert(!mlirTypeIDIsNull(mlirTypeID) &&
36503637
"mlirTypeID was expected to be non-null.");
36513638
std::optional<pybind11::function> typeCaster =
3652-
PyGlobals::withInstance([&](PyGlobals& instance){
3653-
return instance.lookupTypeCaster(mlirTypeID,
3654-
mlirTypeGetDialect(self));
3655-
});
3639+
PyGlobals::get().lookupTypeCaster(mlirTypeID,
3640+
mlirTypeGetDialect(self));
36563641
if (!typeCaster)
36573642
return py::cast(self);
36583643
return typeCaster.value()(self);

mlir/lib/Bindings/Python/IRModule.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,22 @@ bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
5959

6060
if (loaded.is_none())
6161
return false;
62+
63+
// We should use a lock in free-threading as loadDialectModule can be implicitly called by
64+
// Python functions executing concurrently in multiple threads (e.g. lookupValueCaster).
65+
#ifdef Py_GIL_DISABLED
66+
auto &lock = getLock();
67+
PyMutex_Lock(&lock);
68+
#endif
69+
6270
// Note: Iterator cannot be shared from prior to loading, since re-entrancy
6371
// may have occurred, which may do anything.
6472
loadedDialectModules.insert(dialectNamespace);
73+
74+
#ifdef Py_GIL_DISABLED
75+
PyMutex_Unlock(&lock);
76+
#endif
77+
6578
return true;
6679
}
6780

mlir/lib/Bindings/Python/MainModule.cpp

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -74,13 +74,7 @@ PYBIND11_MODULE(_mlir, m, py::mod_gil_not_used()) {
7474
std::string operationName =
7575
opClass.attr("OPERATION_NAME").cast<std::string>();
7676

77-
// Use PyGlobals::withInstance instead of PyGlobals::get()
78-
// to prevent data race in multi-threaded context
79-
// Error raised in ir/opeation.py testKnownOpView test
80-
PyGlobals::withInstance([&](PyGlobals& instance) {
81-
instance.registerOperationImpl(operationName, opClass, replace);
82-
return 0;
83-
});
77+
PyGlobals::get().registerOperationImpl(operationName, opClass, replace);
8478
// Dict-stuff the new opClass by name onto the dialect class.
8579
py::object opClassName = opClass.attr("__name__");
8680
dialectClass.attr(opClassName) = opClass;

mlir/test/python/multithreaded_tests.py

Lines changed: 44 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
2) Tests generation: we use existing tests: test/python/ir/*.py,
1717
test/python/dialects/*.py, etc to generate multi-threaded tests.
1818
In detail, we perform the following:
19-
a) we define a list of source tests to be used to generate multi-threaded tests, see `test_modules`.
19+
a) we define a list of source tests to be used to generate multi-threaded tests, see `TEST_MODULES`.
2020
b) we define `TestAllMultiThreaded` class and add existing tests to the class. See `add_existing_tests` method.
2121
c) for each test file, we copy and modify it: test/python/ir/affine_expr.py -> /tmp/ir/affine_expr.py.
2222
In order to import the test file as python module, we remove all executing functions, like
@@ -376,10 +376,9 @@ def run_construct_and_print_in_module(f):
376376
return f
377377

378378

379-
test_modules = [
379+
TEST_MODULES = [
380380
("execution_engine", run), # Fail,
381381
("pass_manager", run), # Fail
382-
383382
("dialects/affine", run_with_insertion_point_v2), # Pass
384383
("dialects/func", run_with_insertion_point_v2), # Pass
385384
("dialects/arith_dialect", run), # Pass
@@ -410,32 +409,39 @@ def run_construct_and_print_in_module(f):
410409
("dialects/transform_bufferization_ext", run_with_insertion_point_v2), # Pass
411410
# ("dialects/transform_extras", ), # Needs a more complicated execution schema
412411
("dialects/transform_gpu_ext", run_transform_tensor_ext), # Pass
413-
("dialects/transform_interpreter", run_with_context_and_location, ["print_", "transform_options", "failed", "include"]), # Fail
414-
("dialects/transform_loop_ext", run_with_insertion_point_v2, ["loopOutline"]), # Pass
412+
(
413+
"dialects/transform_interpreter",
414+
run_with_context_and_location,
415+
["print_", "transform_options", "failed", "include"],
416+
), # Fail
417+
(
418+
"dialects/transform_loop_ext",
419+
run_with_insertion_point_v2,
420+
["loopOutline"],
421+
), # Pass
415422
("dialects/transform_memref_ext", run_with_insertion_point_v2), # Pass
416423
("dialects/transform_nvgpu_ext", run_with_insertion_point_v2), # Pass
417424
("dialects/transform_sparse_tensor_ext", run_transform_tensor_ext), # Pass
418425
("dialects/transform_structured_ext", run_transform_structured_ext), # Pass
419426
("dialects/transform_tensor_ext", run_transform_tensor_ext), # Pass
420-
("dialects/transform_vector_ext", run_apply_patterns, ["configurable_patterns"]), # Pass
427+
(
428+
"dialects/transform_vector_ext",
429+
run_apply_patterns,
430+
["configurable_patterns"],
431+
), # Pass
421432
("dialects/transform", run_with_insertion_point_v3), # Pass
422433
("dialects/vector", run_with_context_and_location), # Pass
423-
424434
("dialects/gpu/dialect", run_with_context_and_location), # Pass
425435
("dialects/gpu/module-to-binary-nvvm", run_with_context_and_location), # Pass
426436
("dialects/gpu/module-to-binary-rocdl", run_with_context_and_location), # Fail
427-
428437
("dialects/linalg/ops", run), # Pass
429438
# TO ADD: No proper tests in this dialects/linalg/opsdsl/*
430439
# ("dialects/linalg/opsdsl/*", ...), #
431-
432440
("dialects/sparse_tensor/dialect", run), # Pass
433441
("dialects/sparse_tensor/passes", run), # Pass
434-
435442
("integration/dialects/pdl", run_construct_and_print_in_module), # Pass
436443
("integration/dialects/transform", run_construct_and_print_in_module), # Pass
437444
("integration/dialects/linalg/opsrun", run), # Fail
438-
439445
("ir/affine_expr", run), # Pass
440446
("ir/affine_map", run), # Pass
441447
("ir/array_attributes", run), # Pass
@@ -456,16 +462,18 @@ def run_construct_and_print_in_module(f):
456462
("ir/value", run), # Pass
457463
]
458464

459-
tests_to_skip = [
465+
TESTS_TO_SKIP = [
460466
"test_execution_engine__testNanoTime_multi_threaded", # testNanoTime can't run in multiple threads, even with GIL
461467
"test_execution_engine__testSharedLibLoad_multi_threaded", # testSharedLibLoad can't run in multiple threads, even with GIL
462468
"test_dialects_arith_dialect__testArithValue_multi_threaded", # RuntimeError: Value caster is already registered: <class 'dialects/arith_dialect.testArithValue.<locals>.ArithValue'>, even with GIL
463469
"test_ir_dialects__testAppendPrefixSearchPath_multi_threaded", # PyGlobals::setDialectSearchPrefixes is not thread-safe, even with GIL. Strange usage of static PyGlobals vs python exposed _cext.globals
464470
"test_ir_value__testValueCasters_multi_threaded_multi_threaded", # RuntimeError: Value caster is already registered: <function testValueCasters.<locals>.dont_cast_int, even with GIL
471+
"test_ir_operation_testKnownOpView_multi_threaded_multi_threaded", # uses register_operation method in the test
472+
"test_dialects_transform_structured_ext__testMatchInterfaceEnumReplaceAttributeBuilder_multi_threaded", # uses register_attribute_builder method in the test
465473
]
466474

467475

468-
tests_to_xfail = [
476+
TESTS_TO_XFAIL = [
469477
# execution_engine tests, ctypes related data-races, may be false-positive as libffi was not instrumented with TSAN
470478
"test_execution_engine__testBF16Memref_multi_threaded",
471479
"test_execution_engine__testBasicCallback_multi_threaded",
@@ -476,31 +484,28 @@ def run_construct_and_print_in_module(f):
476484
"test_execution_engine__testF8E5M2Memref_multi_threaded",
477485
"test_execution_engine__testInvalidModule_multi_threaded",
478486
"test_execution_engine__testInvokeFloatAdd_multi_threaded",
487+
"test_execution_engine__testInvokeVoid_multi_threaded",
479488
"test_execution_engine__testMemrefAdd_multi_threaded",
480489
"test_execution_engine__testRankedMemRefCallback_multi_threaded",
481490
"test_execution_engine__testRankedMemRefWithOffsetCallback_multi_threaded",
482491
"test_execution_engine__testUnrankedMemRefCallback_multi_threaded",
483492
"test_execution_engine__testUnrankedMemRefWithOffsetCallback_multi_threaded",
484-
485493
# pass_manager tests
486494
"test_pass_manager__testPrintIrAfterAll_multi_threaded", # IRPrinterInstrumentation::runAfterPass is not thread-safe
487495
"test_pass_manager__testPrintIrBeforeAndAfterAll_multi_threaded", # IRPrinterInstrumentation::runBeforePass is not thread-safe
488496
"test_pass_manager__testPrintIrLargeLimitElements_multi_threaded", # IRPrinterInstrumentation::runAfterPass is not thread-safe
489497
"test_pass_manager__testPrintIrTree_multi_threaded", # IRPrinterInstrumentation::runAfterPass is not thread-safe
490498
"test_pass_manager__testRunPipeline_multi_threaded", # PrintOpStatsPass::printSummary is not thread-safe
491-
492499
# dialects tests
493500
"test_dialects_memref__testSubViewOpInferReturnTypeExtensiveSlicing_multi_threaded", # Related to ctypes data races
494501
"test_dialects_transform_interpreter__include_multi_threaded", # mlir::transform::PrintOp::apply(mlir::transform::TransformRewriter...) is not thread-safe
495502
"test_dialects_transform_interpreter__print_other_multi_threaded", # Fatal Python error: Aborted or mlir::transform::PrintOp::apply(mlir::transform::TransformRewriter...) is not thread-safe
496503
"test_dialects_transform_interpreter__print_self_multi_threaded", # mlir::transform::PrintOp::apply(mlir::transform::TransformRewriter...) is not thread-safe
497504
"test_dialects_transform_interpreter__transform_options_multi_threaded", # mlir::transform::PrintOp::apply(mlir::transform::TransformRewriter...) is not thread-safe
498505
"test_dialects_gpu_module-to-binary-rocdl__testGPUToASMBin_multi_threaded", # Due to global llvm-project/llvm/lib/Target/AMDGPU/GCNSchedStrategy.cpp::GCNTrackers variable mutation
499-
500506
# integration tests
501507
"test_integration_dialects_linalg_opsrun__test_elemwise_builtin_multi_threaded", # Related to ctypes data races
502508
"test_integration_dialects_linalg_opsrun__test_elemwise_generic_multi_threaded", # Related to ctypes data races
503-
504509
# IR tests
505510
"test_ir_diagnostic_handler__testDiagnosticCallbackException_multi_threaded", # mlirEmitError is not thread-safe
506511
"test_ir_module__testParseSuccess_multi_threaded", # mlirOperationDump is not thread-safe
@@ -509,7 +514,7 @@ def run_construct_and_print_in_module(f):
509514
]
510515

511516

512-
def add_existing_tests(test_prefix: str = "_original_test"):
517+
def add_existing_tests(test_modules, test_prefix: str = "_original_test"):
513518
def decorator(test_cls):
514519
this_folder = Path(__file__).parent.absolute()
515520
test_cls.output_folder = tempfile.TemporaryDirectory()
@@ -531,16 +536,22 @@ def decorator(test_cls):
531536
test_mod = import_from_path(test_module_name, dst_filepath)
532537
for attr_name in dir(test_mod):
533538
is_test_fn = test_pattern is None and attr_name.startswith("test")
534-
is_test_fn |= test_pattern is not None and any([p in attr_name for p in test_pattern])
539+
is_test_fn |= test_pattern is not None and any(
540+
[p in attr_name for p in test_pattern]
541+
)
535542
if is_test_fn:
536543
obj = getattr(test_mod, attr_name)
537544
if callable(obj):
538545
test_name = f"{test_prefix}_{test_module_name.replace('/', '_')}__{attr_name}"
539-
def wrapped_test_fn(self, *args, __test_fn__=obj, __exec_fn__=exec_fn, **kwargs):
546+
547+
def wrapped_test_fn(
548+
self, *args, __test_fn__=obj, __exec_fn__=exec_fn, **kwargs
549+
):
540550
__exec_fn__(__test_fn__)
541551

542552
setattr(test_cls, test_name, wrapped_test_fn)
543553
return test_cls
554+
544555
return decorator
545556

546557

@@ -553,6 +564,7 @@ def multi_threaded(
553564
multithreaded_test_postfix: str = "_multi_threaded",
554565
):
555566
"""Decorator that runs a test in a multi-threaded environment."""
567+
556568
def decorator(test_cls):
557569
for name, test_fn in test_cls.__dict__.copy().items():
558570
if not (name.startswith(test_prefix) and callable(test_fn)):
@@ -566,7 +578,9 @@ def decorator(test_cls):
566578
):
567579
continue
568580

569-
def multi_threaded_test_fn(self, capfd, *args, __test_fn__=test_fn, **kwargs):
581+
def multi_threaded_test_fn(
582+
self, capfd, *args, __test_fn__=test_fn, **kwargs
583+
):
570584
barrier = threading.Barrier(num_workers)
571585

572586
def closure():
@@ -590,7 +604,9 @@ def closure():
590604
captured = capfd.readouterr()
591605
if len(captured.err) > 0:
592606
if "ThreadSanitizer" in captured.err:
593-
raise RuntimeError(f"ThreadSanitizer reported warnings:\n{captured.err}")
607+
raise RuntimeError(
608+
f"ThreadSanitizer reported warnings:\n{captured.err}"
609+
)
594610
else:
595611
pass
596612
# There are tests that write to stderr, we should ignore them
@@ -603,18 +619,19 @@ def closure():
603619
setattr(test_cls, test_new_name, multi_threaded_test_fn)
604620

605621
return test_cls
622+
606623
return decorator
607624

608625

609626
@multi_threaded(
610-
num_workers=6,
627+
num_workers=8,
611628
num_runs=20,
612-
skip_tests=tests_to_skip,
613-
xfail_tests=tests_to_xfail,
629+
skip_tests=TESTS_TO_SKIP,
630+
xfail_tests=TESTS_TO_XFAIL,
614631
)
615-
@add_existing_tests(test_prefix="_original_test")
632+
@add_existing_tests(test_modules=TEST_MODULES, test_prefix="_original_test")
616633
class TestAllMultiThreaded:
617-
@pytest.fixture(scope='class')
634+
@pytest.fixture(scope="class")
618635
def teardown(self):
619636
if hasattr(self, "output_folder"):
620637
self.output_folder.cleanup()

0 commit comments

Comments
 (0)