File tree Expand file tree Collapse file tree 2 files changed +29
-1
lines changed Expand file tree Collapse file tree 2 files changed +29
-1
lines changed Original file line number Diff line number Diff line change @@ -2743,6 +2743,12 @@ void mlir::python::populateIRCore(nb::module_ &m) {
27432743 // __init__.py will subclass it with site-specific functionality and set a
27442744 // "Context" attribute on this module.
27452745 // ----------------------------------------------------------------------------
2746+
2747+ // Expose DefaultThreadPool to python
2748+ nb::class_<PyThreadPool>(m, " ThreadPool" )
2749+ .def (" __init__" , [](PyThreadPool &self) { new (&self) PyThreadPool (); })
2750+ .def (" get_max_concurrency" , &PyThreadPool::getMaxConcurrency);
2751+
27462752 nb::class_<PyMlirContext>(m, " _BaseContext" )
27472753 .def (" __init__" ,
27482754 [](PyMlirContext &self) {
@@ -2814,6 +2820,10 @@ void mlir::python::populateIRCore(nb::module_ &m) {
28142820 mlirContextEnableMultithreading (self.get (), enable);
28152821 },
28162822 nb::arg (" enable" ))
2823+ .def (" set_thread_pool" ,
2824+ [](PyMlirContext &self, PyThreadPool &pool) {
2825+ mlirContextSetThreadPool (self.get (), pool.get ());
2826+ })
28172827 .def (
28182828 " is_registered_operation" ,
28192829 [](PyMlirContext &self, std::string &name) {
Original file line number Diff line number Diff line change 2222#include " mlir-c/IR.h"
2323#include " mlir-c/IntegerSet.h"
2424#include " mlir-c/Transforms.h"
25- #include " mlir/Bindings/Python/NanobindAdaptors.h"
2625#include " mlir/Bindings/Python/Nanobind.h"
26+ #include " mlir/Bindings/Python/NanobindAdaptors.h"
2727#include " llvm/ADT/DenseMap.h"
28+ #include " llvm/Support/ThreadPool.h"
2829
2930namespace mlir {
3031namespace python {
@@ -158,6 +159,23 @@ class PyThreadContextEntry {
158159 FrameKind frameKind;
159160};
160161
162+ // / Wrapper around MlirLlvmThreadPool
163+ // / Python object owns the C++ thread pool
164+ class PyThreadPool {
165+ public:
166+ PyThreadPool () {
167+ ownedThreadPool = std::make_unique<llvm::DefaultThreadPool>();
168+ }
169+ PyThreadPool (const PyThreadPool &) = delete ;
170+ PyThreadPool (PyThreadPool &&) = delete ;
171+
172+ int getMaxConcurrency () const { return ownedThreadPool->getMaxConcurrency (); }
173+ MlirLlvmThreadPool get () { return wrap (ownedThreadPool.get ()); }
174+
175+ private:
176+ std::unique_ptr<llvm::ThreadPoolInterface> ownedThreadPool;
177+ };
178+
161179// / Wrapper around MlirContext.
162180using PyMlirContextRef = PyObjectRef<PyMlirContext>;
163181class PyMlirContext {
You can’t perform that action at this time.
0 commit comments