Skip to content

Commit 04e4c69

Browse files
superbobryGoogle-ML-Automation
authored andcommitted
[mosaic_gpu] Handle older jaxlibs in the profiler module
`measure` now raises a `RuntimeError` if the available `jaxlib` does not have the required custom calls. PiperOrigin-RevId: 698351662
1 parent f442d40 commit 04e4c69

File tree

1 file changed

+13
-5
lines changed

1 file changed

+13
-5
lines changed

jax/experimental/mosaic/gpu/profiler.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,15 @@
3636
try:
3737
from jax._src.lib import mosaic_gpu as mosaic_gpu_lib
3838
except ImportError:
39-
pass
39+
has_registrations = False
4040
else:
41-
for name, handler in mosaic_gpu_lib._mosaic_gpu_ext.registrations():
42-
xla_client.register_custom_call_target(
43-
name, handler, platform="CUDA", api_version=1
44-
)
41+
# TODO(slebedev): Remove the if once the minimum jaxlib is 0.4.36.
42+
has_registrations = hasattr(mosaic_gpu_lib._mosaic_gpu_ext, "registrations")
43+
if has_registrations:
44+
for name, handler in mosaic_gpu_lib._mosaic_gpu_ext.registrations():
45+
xla_client.register_custom_call_target(
46+
name, handler, platform="CUDA", api_version=1
47+
)
4548

4649
# ruff: noqa: F405
4750
# mypy: ignore-errors
@@ -80,6 +83,11 @@ def measure(
8083
Returns:
8184
The return value of ``f`` and the elapsed time in milliseconds.
8285
"""
86+
if not has_registrations:
87+
raise RuntimeError(
88+
"This function requires jaxlib >=0.4.36 with CUDA support."
89+
)
90+
8391
if not (args or kwargs):
8492
# We require at least one argument and at least one output to ensure
8593
# that there is a data dependency between `_event_record` calls in

0 commit comments

Comments
 (0)