From 922415d694262c71a8c78fa3cd6244ff34280171 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Thu, 6 Mar 2025 14:46:46 +0100 Subject: [PATCH 1/4] Added PyThreadPool as wrapper around MlirLlvmThreadPool in MLIR python bindings --- mlir/lib/Bindings/Python/IRCore.cpp | 10 ++++++++++ mlir/lib/Bindings/Python/IRModule.h | 20 +++++++++++++++++++- 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 12793f7dd15be..1ec52a1a9bcd4 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -2743,6 +2743,12 @@ void mlir::python::populateIRCore(nb::module_ &m) { // __init__.py will subclass it with site-specific functionality and set a // "Context" attribute on this module. //---------------------------------------------------------------------------- + + // Expose DefaultThreadPool to python + nb::class_(m, "ThreadPool") + .def("__init__", [](PyThreadPool &self) { new (&self) PyThreadPool(); }) + .def("get_max_concurrency", &PyThreadPool::getMaxConcurrency); + nb::class_(m, "_BaseContext") .def("__init__", [](PyMlirContext &self) { @@ -2814,6 +2820,10 @@ void mlir::python::populateIRCore(nb::module_ &m) { mlirContextEnableMultithreading(self.get(), enable); }, nb::arg("enable")) + .def("set_thread_pool", + [](PyMlirContext &self, PyThreadPool &pool) { + mlirContextSetThreadPool(self.get(), pool.get()); + }) .def( "is_registered_operation", [](PyMlirContext &self, std::string &name) { diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 1ed6240a6ca69..b7bbd646d982e 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -22,9 +22,10 @@ #include "mlir-c/IR.h" #include "mlir-c/IntegerSet.h" #include "mlir-c/Transforms.h" -#include "mlir/Bindings/Python/NanobindAdaptors.h" #include "mlir/Bindings/Python/Nanobind.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/Support/ThreadPool.h" namespace mlir { namespace python { @@ -158,6 +159,23 @@ class PyThreadContextEntry { FrameKind frameKind; }; +/// Wrapper around MlirLlvmThreadPool +/// Python object owns the C++ thread pool +class PyThreadPool { +public: + PyThreadPool() { + ownedThreadPool = std::make_unique(); + } + PyThreadPool(const PyThreadPool &) = delete; + PyThreadPool(PyThreadPool &&) = delete; + + int getMaxConcurrency() const { return ownedThreadPool->getMaxConcurrency(); } + MlirLlvmThreadPool get() { return wrap(ownedThreadPool.get()); } + +private: + std::unique_ptr ownedThreadPool; +}; + /// Wrapper around MlirContext. using PyMlirContextRef = PyObjectRef; class PyMlirContext { From 5b6ca6aa0661d9efe4abf2db48e59d16f6192435 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Thu, 6 Mar 2025 15:58:15 +0100 Subject: [PATCH 2/4] Added a test --- mlir/test/python/ir/context_lifecycle.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/mlir/test/python/ir/context_lifecycle.py b/mlir/test/python/ir/context_lifecycle.py index c20270999425e..93a98e7f8e9f5 100644 --- a/mlir/test/python/ir/context_lifecycle.py +++ b/mlir/test/python/ir/context_lifecycle.py @@ -47,3 +47,17 @@ assert '"mlir.ir.Context._CAPIPtr"' in repr(c4_capsule) c5 = mlir.ir.Context._CAPICreate(c4_capsule) assert c4 is c5 +c4 = None +c5 = None +gc.collect() + +# Create a global threadpool and use it in two contexts +tp = mlir.ir.ThreadPool() +assert tp.get_max_concurrency() > 0 +c5 = mlir.ir.Context() +c5.enable_multithreading(False) +c5.set_thread_pool(tp) +c6 = mlir.ir.Context() +c6.enable_multithreading(False) +c6.set_thread_pool(tp) +assert mlir.ir.Context._get_live_count() == 2 From e64830122e0cac98b8a2832307f1134be63d49af Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Fri, 7 Mar 2025 11:44:27 +0100 Subject: [PATCH 3/4] Added get_num_threads and _mlir_thread_pool_ptr methods to _BaseContext Added thread_pool arg to the constructor: `mlir.ir.Context(thread_pool=tp)` --- mlir/include/mlir-c/IR.h | 9 +++++++++ mlir/lib/Bindings/Python/IRCore.cpp | 18 +++++++++++++++++- mlir/lib/Bindings/Python/IRModule.h | 7 +++++++ mlir/lib/CAPI/IR/IR.cpp | 8 ++++++++ mlir/python/mlir/_mlir_libs/__init__.py | 9 +++++++-- mlir/test/python/ir/context_lifecycle.py | 15 ++++++++++++--- 6 files changed, 60 insertions(+), 6 deletions(-) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index d562da1f90757..001660ee51311 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -162,6 +162,15 @@ MLIR_CAPI_EXPORTED bool mlirContextIsRegisteredOperation(MlirContext context, MLIR_CAPI_EXPORTED void mlirContextSetThreadPool(MlirContext context, MlirLlvmThreadPool threadPool); +/// Gets the number of threads of the thread pool of the context when +/// multithreading is enabled. Returns 1 if no multithreading. +MLIR_CAPI_EXPORTED unsigned mlirContextGetNumThreads(MlirContext context); + +/// Gets the thread pool of the context when enabled multithreading, otherwise +/// an assertion is raised. +MLIR_CAPI_EXPORTED MlirLlvmThreadPool +mlirContextGetThreadPool(MlirContext context); + //===----------------------------------------------------------------------===// // Dialect API. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 1ec52a1a9bcd4..22d6d117573b9 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -2747,7 +2747,8 @@ void mlir::python::populateIRCore(nb::module_ &m) { // Expose DefaultThreadPool to python nb::class_(m, "ThreadPool") .def("__init__", [](PyThreadPool &self) { new (&self) PyThreadPool(); }) - .def("get_max_concurrency", &PyThreadPool::getMaxConcurrency); + .def("get_max_concurrency", &PyThreadPool::getMaxConcurrency) + .def("_mlir_thread_pool_ptr", &PyThreadPool::_mlir_thread_pool_ptr); nb::class_(m, "_BaseContext") .def("__init__", @@ -2822,8 +2823,23 @@ void mlir::python::populateIRCore(nb::module_ &m) { nb::arg("enable")) .def("set_thread_pool", [](PyMlirContext &self, PyThreadPool &pool) { + // we should disable multi-threading first before setting + // new thread pool otherwise the assert in + // MLIRContext::setThreadPool will be raised. + mlirContextEnableMultithreading(self.get(), false); mlirContextSetThreadPool(self.get(), pool.get()); }) + .def("get_num_threads", + [](PyMlirContext &self) { + return mlirContextGetNumThreads(self.get()); + }) + .def("_mlir_thread_pool_ptr", + [](PyMlirContext &self) { + MlirLlvmThreadPool pool = mlirContextGetThreadPool(self.get()); + std::stringstream ss; + ss << pool.ptr; + return ss.str(); + }) .def( "is_registered_operation", [](PyMlirContext &self, std::string &name) { diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index b7bbd646d982e..9befcce725bb7 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -11,6 +11,7 @@ #define MLIR_BINDINGS_PYTHON_IRMODULES_H #include +#include #include #include @@ -172,6 +173,12 @@ class PyThreadPool { int getMaxConcurrency() const { return ownedThreadPool->getMaxConcurrency(); } MlirLlvmThreadPool get() { return wrap(ownedThreadPool.get()); } + std::string _mlir_thread_pool_ptr() const { + std::stringstream ss; + ss << ownedThreadPool.get(); + return ss.str(); + } + private: std::unique_ptr ownedThreadPool; }; diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 6cd9ba2aef233..649f3b7056fb0 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -114,6 +114,14 @@ void mlirContextSetThreadPool(MlirContext context, unwrap(context)->setThreadPool(*unwrap(threadPool)); } +unsigned mlirContextGetNumThreads(MlirContext context) { + return unwrap(context)->getNumThreads(); +} + +MlirLlvmThreadPool mlirContextGetThreadPool(MlirContext context) { + return wrap(&unwrap(context)->getThreadPool()); +} + //===----------------------------------------------------------------------===// // Dialect API. //===----------------------------------------------------------------------===// diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py index d021dde05dd87..c480a0035313d 100644 --- a/mlir/python/mlir/_mlir_libs/__init__.py +++ b/mlir/python/mlir/_mlir_libs/__init__.py @@ -148,13 +148,18 @@ def process_initializer_module(module_name): break class Context(ir._BaseContext): - def __init__(self, load_on_create_dialects=None, *args, **kwargs): + def __init__( + self, load_on_create_dialects=None, thread_pool=None, *args, **kwargs + ): super().__init__(*args, **kwargs) self.append_dialect_registry(get_dialect_registry()) for hook in post_init_hooks: hook(self) if not disable_multithreading: - self.enable_multithreading(True) + if thread_pool is None: + self.enable_multithreading(True) + else: + self.set_thread_pool(thread_pool) if load_on_create_dialects is not None: logger.debug( "Loading all dialects from load_on_create_dialects arg %r", diff --git a/mlir/test/python/ir/context_lifecycle.py b/mlir/test/python/ir/context_lifecycle.py index 93a98e7f8e9f5..230db8277c8e7 100644 --- a/mlir/test/python/ir/context_lifecycle.py +++ b/mlir/test/python/ir/context_lifecycle.py @@ -55,9 +55,18 @@ tp = mlir.ir.ThreadPool() assert tp.get_max_concurrency() > 0 c5 = mlir.ir.Context() -c5.enable_multithreading(False) c5.set_thread_pool(tp) +assert c5.get_num_threads() == tp.get_max_concurrency() +assert c5._mlir_thread_pool_ptr() == tp._mlir_thread_pool_ptr() c6 = mlir.ir.Context() -c6.enable_multithreading(False) c6.set_thread_pool(tp) -assert mlir.ir.Context._get_live_count() == 2 +assert c6.get_num_threads() == tp.get_max_concurrency() +assert c6._mlir_thread_pool_ptr() == tp._mlir_thread_pool_ptr() +c7 = mlir.ir.Context(thread_pool=tp) +assert c7.get_num_threads() == tp.get_max_concurrency() +assert c7._mlir_thread_pool_ptr() == tp._mlir_thread_pool_ptr() +assert mlir.ir.Context._get_live_count() == 3 +c5 = None +c6 = None +c7 = None +gc.collect() From a557554c7f60f21f66d0c261fd45c4082c4269e9 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Sat, 8 Mar 2025 14:36:13 +0100 Subject: [PATCH 4/4] Raise error if disable_multithreading and thread_pool is given --- mlir/python/mlir/_mlir_libs/__init__.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py index c480a0035313d..083a9075fe4c5 100644 --- a/mlir/python/mlir/_mlir_libs/__init__.py +++ b/mlir/python/mlir/_mlir_libs/__init__.py @@ -155,6 +155,13 @@ def __init__( self.append_dialect_registry(get_dialect_registry()) for hook in post_init_hooks: hook(self) + if disable_multithreading and thread_pool is not None: + raise ValueError( + "Context constructor has given thread_pool argument, " + "but disable_multithreading flag is True. " + "Please, set thread_pool argument to None or " + "set disable_multithreading flag to False." + ) if not disable_multithreading: if thread_pool is None: self.enable_multithreading(True)