@@ -593,6 +593,15 @@ def module_to_bytecode(module: ir.Module) -> bytes:
593593
594594# Translation rules
595595
596+ # Create one global thread pool that can be shared between multiple ir.Contexts
597+ # and enabling multi-threading
598+ # TODO: remove this check after jaxlib 0.5.4
599+ if hasattr (ir , "ThreadPool" ):
600+ global_thread_pool = ir .ThreadPool ()
601+ else :
602+ global_thread_pool = None
603+
604+
596605class JaxIrContext (ir .Context ):
597606 def __init__ (self , * args , ** kwargs ):
598607 # Note: we're very intentionally *not* calling the __init__() of our
@@ -607,12 +616,16 @@ def make_ir_context() -> ir.Context:
607616 context .append_dialect_registry (upstream_dialects )
608617 context .load_all_available_dialects ()
609618
610- # If threading is enabled, each MLIR context will keep alive a thread pool.
611- # Since we cache MLIR modules (and hence contexts), this means we might keep
612- # several threads alive for each cache entry. This is a terrible idea. However
613- # we don't do any heavy computation on MLIR modules from Python anyway, so we
614- # just disable threading.
615- context .enable_multithreading (False )
619+ # TODO: remove this check after v0.5.4 jaxlib
620+ if global_thread_pool is not None :
621+ context .set_thread_pool (global_thread_pool )
622+ else :
623+ # If threading is enabled, each MLIR context will keep alive a thread pool.
624+ # Since we cache MLIR modules (and hence contexts), this means we might keep
625+ # several threads alive for each cache entry. This is a terrible idea. However
626+ # we don't do any heavy computation on MLIR modules from Python anyway, so we
627+ # just disable threading.
628+ context .enable_multithreading (False )
616629 # TODO(bartchr): Once JAX is released with SDY, remove the if.
617630 if dialects .sdy :
618631 dialects .sdy .register_dialect (context )
0 commit comments