File tree Expand file tree Collapse file tree 1 file changed +13
-5
lines changed
jax/experimental/mosaic/gpu Expand file tree Collapse file tree 1 file changed +13
-5
lines changed Original file line number Diff line number Diff line change 3636try :
3737 from jax ._src .lib import mosaic_gpu as mosaic_gpu_lib
3838except ImportError :
39- pass
39+ has_registrations = False
4040else :
41- for name , handler in mosaic_gpu_lib ._mosaic_gpu_ext .registrations ():
42- xla_client .register_custom_call_target (
43- name , handler , platform = "CUDA" , api_version = 1
44- )
41+ # TODO(slebedev): Remove the if once the minimum jaxlib is 0.4.36.
42+ has_registrations = hasattr (mosaic_gpu_lib ._mosaic_gpu_ext , "registrations" )
43+ if has_registrations :
44+ for name , handler in mosaic_gpu_lib ._mosaic_gpu_ext .registrations ():
45+ xla_client .register_custom_call_target (
46+ name , handler , platform = "CUDA" , api_version = 1
47+ )
4548
4649# ruff: noqa: F405
4750# mypy: ignore-errors
@@ -80,6 +83,11 @@ def measure(
8083 Returns:
8184 The return value of ``f`` and the elapsed time in milliseconds.
8285 """
86+ if not has_registrations :
87+ raise RuntimeError (
88+ "This function requires jaxlib >=0.4.36 with CUDA support."
89+ )
90+
8391 if not (args or kwargs ):
8492 # We require at least one argument and at least one output to ensure
8593 # that there is a data dependency between `_event_record` calls in
You can’t perform that action at this time.
0 commit comments