Skip to content

Commit 55b55e6

Browse files
committed
Enable multi-threading in Jax Context with shared thread pool
1 parent d0b71fa commit 55b55e6

File tree

1 file changed

+19
-6
lines changed

1 file changed

+19
-6
lines changed

jax/_src/interpreters/mlir.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -593,6 +593,15 @@ def module_to_bytecode(module: ir.Module) -> bytes:
593593

594594
# Translation rules
595595

596+
# Create one global thread pool that can be shared between multiple ir.Contexts
597+
# and enabling multi-threading
598+
# TODO: remove this check after jaxlib 0.5.4
599+
if hasattr(ir, "ThreadPool"):
600+
global_thread_pool = ir.ThreadPool()
601+
else:
602+
global_thread_pool = None
603+
604+
596605
class JaxIrContext(ir.Context):
597606
def __init__(self, *args, **kwargs):
598607
# Note: we're very intentionally *not* calling the __init__() of our
@@ -607,12 +616,16 @@ def make_ir_context() -> ir.Context:
607616
context.append_dialect_registry(upstream_dialects)
608617
context.load_all_available_dialects()
609618

610-
# If threading is enabled, each MLIR context will keep alive a thread pool.
611-
# Since we cache MLIR modules (and hence contexts), this means we might keep
612-
# several threads alive for each cache entry. This is a terrible idea. However
613-
# we don't do any heavy computation on MLIR modules from Python anyway, so we
614-
# just disable threading.
615-
context.enable_multithreading(False)
619+
# TODO: remove this check after v0.5.4 jaxlib
620+
if global_thread_pool is not None:
621+
context.set_thread_pool(global_thread_pool)
622+
else:
623+
# If threading is enabled, each MLIR context will keep alive a thread pool.
624+
# Since we cache MLIR modules (and hence contexts), this means we might keep
625+
# several threads alive for each cache entry. This is a terrible idea. However
626+
# we don't do any heavy computation on MLIR modules from Python anyway, so we
627+
# just disable threading.
628+
context.enable_multithreading(False)
616629
# TODO(bartchr): Once JAX is released with SDY, remove the if.
617630
if dialects.sdy:
618631
dialects.sdy.register_dialect(context)

0 commit comments

Comments
 (0)