Closed
Description
For any JAX kernel, I get the error below.
I tried running this on every GPU type, and it failed on all of them.
import jax
import jax.numpy as jnp

# A, B are tensors on GPU
@jax.jit
def solve(A: jax.Array, B: jax.Array, N: int) -> jax.Array:
    return A + B

E0303 05:59:07.961540 41 cuda_dnn.cc:454] Loaded runtime CuDNN library: 9.5.1 but source was compiled with: 9.8.0. CuDNN library needs to have matching major version and equal or higher minor version. If using a binary install, upgrade your CuDNN library. If building from sources, make sure the library loaded at runtime is compatible with the version specified during compile configuration.
E0303 05:59:07.968398 41 cuda_dnn.cc:454] Loaded runtime CuDNN library: 9.5.1 but source was compiled with: 9.8.0. CuDNN library needs to have matching major version and equal or higher minor version. If using a binary install, upgrade your CuDNN library. If building from sources, make sure the library loaded at runtime is compatible with the version specified during compile configuration.
E0303 05:59:08.030018 41 cuda_dnn.cc:454] Loaded runtime CuDNN library: 9.5.1 but source was compiled with: 9.8.0. CuDNN library needs to have matching major version and equal or higher minor version. If using a binary install, upgrade your CuDNN library. If building from sources, make sure the library loaded at runtime is compatible with the version specified during compile configuration.
E0303 05:59:08.038274 41 cuda_dnn.cc:454] Loaded runtime CuDNN library: 9.5.1 but source was compiled with: 9.8.0. CuDNN library needs to have matching major version and equal or higher minor version. If using a binary install, upgrade your CuDNN library. If building from sources, make sure the library loaded at runtime is compatible with the version specified during compile configuration.
Traceback (most recent call last):
File "runner/runner.py", line 216, in _run_single_test_case
File "/usr/local/lib/python3.12/dist-packages/jax/_src/numpy/array_creation.py", line 94, in zeros
return lax.full(shape, 0, dtype, sharding=sharding)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/jax/_src/lax/lax.py", line 3381, in full
fill_value = _convert_element_type(fill_value, fill_dtype, weak_type)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/jax/_src/lax/lax.py", line 1689, in _convert_element_type
return convert_element_type_p.bind(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/jax/_src/core.py", line 632, in bind
return self._true_bind(*args, **params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/jax/_src/core.py", line 648, in _true_bind
return self.bind_with_trace(prev_trace, args, params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/jax/_src/lax/lax.py", line 5012, in _convert_element_type_bind_with_trace
operand = core.Primitive.bind_with_trace(convert_element_type_p, trace, args, params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/jax/_src/core.py", line 660, in bind_with_trace
return trace.process_primitive(self, args, params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/jax/_src/core.py", line 1205, in process_primitive
return primitive.impl(*args, **params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/jax/_src/dispatch.py", line 91, in apply_primitive
outs = fun(*args)
^^^^^^^^^^
jax.errors.JaxRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
Test case failed
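The `cuda_dnn.cc` messages in the log state the compatibility rule directly: the cuDNN library loaded at runtime must have the same major version as the one the source was compiled with, and an equal or higher minor version. A minimal sketch of that check, assuming only the Python standard library; `parse_cudnn_versions` and `cudnn_compatible` are hypothetical helper names, not part of JAX or cuDNN, and the patch version is ignored because the message does not mention it.

```python
import re
from typing import Optional, Tuple

Version = Tuple[int, int, int]

# Matches the wording of the cuda_dnn.cc log line quoted in this issue.
_CUDNN_MSG = re.compile(
    r"Loaded runtime CuDNN library: (\d+)\.(\d+)\.(\d+) "
    r"but source was compiled with: (\d+)\.(\d+)\.(\d+)"
)


def parse_cudnn_versions(line: str) -> Optional[Tuple[Version, Version]]:
    """Return (runtime, compiled) version tuples, or None if the line does not match."""
    m = _CUDNN_MSG.search(line)
    if m is None:
        return None
    nums = tuple(int(g) for g in m.groups())
    return (nums[0], nums[1], nums[2]), (nums[3], nums[4], nums[5])


def cudnn_compatible(runtime: Version, compiled: Version) -> bool:
    """Same major version, and runtime minor >= compiled minor (patch ignored)."""
    return runtime[0] == compiled[0] and runtime[1] >= compiled[1]


if __name__ == "__main__":
    line = (
        "E0303 05:59:07.961540 41 cuda_dnn.cc:454] Loaded runtime CuDNN library: "
        "9.5.1 but source was compiled with: 9.8.0. CuDNN library needs to have "
        "matching major version and equal or higher minor version."
    )
    parsed = parse_cudnn_versions(line)
    if parsed is not None:
        runtime, compiled = parsed
        print(runtime, compiled, cudnn_compatible(runtime, compiled))
```

Applied to the log above, the sketch reports runtime 9.5.1 against compiled 9.8.0 and returns `False` (5 is lower than 8), which is consistent with the `DNN library initialization failed` error raised from `jnp.zeros` before the kernel itself runs.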