Skip to content

Commit 9fb18e1

Browse files
committed
[skip-ci] More tests and added a lock to _cext.register_operation
1 parent 65a1006 commit 9fb18e1

File tree

2 files changed

+66
-8
lines changed

2 files changed

+66
-8
lines changed

mlir/lib/Bindings/Python/MainModule.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,14 @@ PYBIND11_MODULE(_mlir, m, py::mod_gil_not_used()) {
7373
[dialectClass, replace](py::type opClass) -> py::type {
7474
std::string operationName =
7575
opClass.attr("OPERATION_NAME").cast<std::string>();
76-
PyGlobals::get().registerOperationImpl(operationName, opClass,
77-
replace);
7876

77+
// Use PyGlobals::withInstance instead of PyGlobals::get()
78+
// to prevent data race in multi-threaded context
79+
// Error raised in ir/opeation.py testKnownOpView test
80+
PyGlobals::withInstance([&](PyGlobals& instance) {
81+
instance.registerOperationImpl(operationName, opClass, replace);
82+
return 0;
83+
});
7984
// Dict-stuff the new opClass by name onto the dialect class.
8085
py::object opClassName = opClass.attr("__name__");
8186
dialectClass.attr(opClassName) = opClass;

mlir/test/python/multithreaded_tests.py

Lines changed: 59 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import concurrent.futures
22
import functools
3+
import gc
34
import importlib.util
45
import sys
56
import threading
@@ -42,9 +43,55 @@ def copy_and_update(src_filepath: Path, dst_filepath: Path):
4243
writer.write(src_line)
4344

4445

46+
def run(f):
47+
f()
48+
49+
50+
def constructAndPrintInModule(f):
51+
print("\nTEST:", f.__name__)
52+
with Context(), Location.unknown():
53+
module = Module.create()
54+
with InsertionPoint(module.body):
55+
f()
56+
print(module)
57+
58+
59+
def run_with_context_and_location(f):
60+
print("\nTEST:", f.__name__)
61+
with Context(), Location.unknown():
62+
f()
63+
return f
64+
65+
4566
test_modules = [
46-
"execution_engine",
47-
# "pass_manager",
67+
("execution_engine", run), # Fail
68+
("pass_manager", run), # Fail
69+
70+
# Dialects tests
71+
("dialects/affine", constructAndPrintInModule), # Fail
72+
("dialects/vector", run_with_context_and_location), # Fail
73+
74+
# IR tests
75+
("ir/affine_expr", run), # Pass
76+
("ir/affine_map", run), # Pass
77+
("ir/array_attributes", run), # Pass
78+
("ir/attributes", run), # Pass
79+
("ir/blocks", run), # Pass
80+
("ir/builtin_types", run), # Pass
81+
("ir/context_managers", run), # Pass
82+
("ir/debug", run), # Fail
83+
("ir/diagnostic_handler", run), # Fail
84+
("ir/dialects", run), # Fail
85+
("ir/exception", run), # Fail
86+
("ir/insertion_point", run), # Pass
87+
("ir/insertion_point", run), # Pass
88+
("ir/integer_set", run), # Pass
89+
("ir/location", run), # Pass
90+
("ir/module", run), # Pass but may fail randomly on mlirOperationDump in testParseSuccess
91+
("ir/operation", run), # Pass
92+
("ir/symbol_table", run), # Pass
93+
("ir/value", run), # Fail/Crash
94+
4895
]
4996

5097

@@ -54,7 +101,7 @@ def decorator(test_cls):
54101
test_cls.output_folder = tempfile.TemporaryDirectory()
55102
output_folder = Path(test_cls.output_folder.name)
56103

57-
for test_module_name in test_modules:
104+
for test_module_name, exec_fn in test_modules:
58105
src_filepath = this_folder / f"{test_module_name}.py"
59106
dst_filepath = (output_folder / f"{test_module_name}.py").absolute()
60107
if not dst_filepath.parent.exists():
@@ -66,8 +113,8 @@ def decorator(test_cls):
66113
obj = getattr(test_mod, attr_name)
67114
if callable(obj):
68115
test_name = f"{test_prefix}_{test_module_name.replace('/', '_')}__{attr_name}"
69-
def wrapped_test_fn(*args, __test_fn__=obj, **kwargs):
70-
__test_fn__()
116+
def wrapped_test_fn(self, *args, __test_fn__=obj, __exec_fn__=exec_fn, **kwargs):
117+
__exec_fn__(__test_fn__)
71118

72119
setattr(test_cls, test_name, wrapped_test_fn)
73120
return test_cls
@@ -99,6 +146,10 @@ def closure():
99146
for _ in range(num_runs):
100147
__test_fn__(self, *args, **kwargs)
101148

149+
barrier.wait()
150+
gc.collect()
151+
assert Context._get_live_count() == 0
152+
102153
with concurrent.futures.ThreadPoolExecutor(
103154
max_workers=num_workers
104155
) as executor:
@@ -114,7 +165,9 @@ def closure():
114165
if "ThreadSanitizer" in captured.err:
115166
raise RuntimeError(f"ThreadSanitizer reported warnings:\n{captured.err}")
116167
else:
117-
raise RuntimeError(f"Other error:\n{captured.err}")
168+
pass
169+
# There are tests that write to stderr, we should ignore them
170+
# raise RuntimeError(f"Other error:\n{captured.err}")
118171

119172
setattr(test_cls, f"{name}_multi_threaded", multi_threaded_test_fn)
120173

0 commit comments

Comments
 (0)