@@ -2747,7 +2747,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
27472747 // Expose DefaultThreadPool to python
27482748 nb::class_<PyThreadPool>(m, " ThreadPool" )
27492749 .def (" __init__" , [](PyThreadPool &self) { new (&self) PyThreadPool (); })
2750- .def (" get_max_concurrency" , &PyThreadPool::getMaxConcurrency);
2750+ .def (" get_max_concurrency" , &PyThreadPool::getMaxConcurrency)
2751+ .def (" _mlir_thread_pool_ptr" , &PyThreadPool::_mlir_thread_pool_ptr);
27512752
27522753 nb::class_<PyMlirContext>(m, " _BaseContext" )
27532754 .def (" __init__" ,
@@ -2822,8 +2823,23 @@ void mlir::python::populateIRCore(nb::module_ &m) {
28222823 nb::arg (" enable" ))
28232824 .def (" set_thread_pool" ,
28242825 [](PyMlirContext &self, PyThreadPool &pool) {
2826+ // we should disable multi-threading first before setting
2827+ // new thread pool otherwise the assert in
2828+ // MLIRContext::setThreadPool will be raised.
2829+ mlirContextEnableMultithreading (self.get (), false );
28252830 mlirContextSetThreadPool (self.get (), pool.get ());
28262831 })
2832+ .def (" get_num_threads" ,
2833+ [](PyMlirContext &self) {
2834+ return mlirContextGetNumThreads (self.get ());
2835+ })
2836+ .def (" _mlir_thread_pool_ptr" ,
2837+ [](PyMlirContext &self) {
2838+ MlirLlvmThreadPool pool = mlirContextGetThreadPool (self.get ());
2839+ std::stringstream ss;
2840+ ss << pool.ptr ;
2841+ return ss.str ();
2842+ })
28272843 .def (
28282844 " is_registered_operation" ,
28292845 [](PyMlirContext &self, std::string &name) {
0 commit comments