I keep getting this rather obscure error every once in a while when I run a jax.jit-compiled function. It does not happen on every run, only sometimes.
The traceback below points to this line: errs = {key: jnp.nanmedian(jnp.abs(val["err"])).item() for key, val in hist.items()}. However, the function that runs this line is not jitted. Any idea what might be causing this?
Thanks.
0%| | 0/30 [00:00<?, ?it/s]/home/gerardoduran/miniconda3/lib/python3.12/site-packages/jax/_src/numpy/array_methods.py:66: UserWarning: Explicitly requested dtype <class 'jax.numpy.float64'> requested in astype is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
return lax_numpy.astype(arr, dtype)
33%|███████████████████████████████████████████ | 10/30 [01:01<02:03, 6.19s/it]
Traceback (most recent call last):
File "/home/gerardoduran/.../lr_variable_selection.py", line 202, in <module>
errs = run_and_eval(key, p_change, n_samples, K, alpha=alpha)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/gerardoduran/.../lr_variable_selection.py", line 180, in run_and_eval
errs = {key: jnp.nanmedian(jnp.abs(val["err"])).item() for key, val in hist.items()}
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/gerardoduran/miniconda3/lib/python3.12/site-packages/jax/_src/numpy/array_methods.py", line 76, in _item
arr = core.concrete_or_error(np.asarray, a, context="This occurred in the item() method of jax.Array")
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/gerardoduran/miniconda3/lib/python3.12/site-packages/jax/_src/core.py", line 1511, in concrete_or_error
return force(val)
^^^^^^^^^^
File "/home/gerardoduran/miniconda3/lib/python3.12/site-packages/jax/_src/array.py", line 407, in __array__
return np.asarray(self._value, dtype=dtype, **kwds)
^^^^^^^^^^^
File "/home/gerardoduran/miniconda3/lib/python3.12/site-packages/jax/_src/profiler.py", line 335, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/gerardoduran/miniconda3/lib/python3.12/site-packages/jax/_src/array.py", line 621, in _value
self._npy_value = self._single_device_array_to_np_array() # type: ignore
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Program or fatal error occurred; computation may be invalid: FAILED_PRECONDITION: Can't read TensorCore tag value when not Open.
=== Source Location Trace: ===
platforms/asic_sw/driver/2a886c8/jxc/common/internal/tensor_node.cc:1000
learning/45eac/tpu/runtime/hal/internal/jxc/tpu_core_debug_interface_jxc_driver_impl.cc:63
learning/45eac/tpu/runtime/hal/internal/tpu_program_termination_validation.cc:49
I'm running jax on a TPU V3-8 and the jax version I'm using is
Update: I get the same error even if I replace jnp with np on that line.
[...]
File "/home/gerardoduran/documents/adaptive-weighted-likelihood-filter/experiments/lr_variable_selection.py", line 179, in run_and_eval
errs = {key: np.nanmedian(np.abs(val["err"])).item() for key, val in hist.items()}
^^^^^^^^^^^^^^^^^^
File "/home/gerardoduran/miniconda3/lib/python3.12/site-packages/jax/_src/array.py", line 407, in __array__
return np.asarray(self._value, dtype=dtype, **kwds)
^^^^^^^^^^^
[...]
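One possible reading, which nothing in this thread confirms: JAX dispatches device work asynchronously, so a failure inside an earlier jitted call may only be reported when a result is copied back to the host. Both .item() and np.asarray do that copy, which would fit the error showing up in a non-jitted function and persisting after swapping jnp for np. A minimal sketch for narrowing down where the failure originates is to block on the jitted output right after the call. Here step, run_and_eval_sketch, and the hist layout are hypothetical stand-ins, not the original code:

```python
import jax
import jax.numpy as jnp


@jax.jit
def step(x):
    # Hypothetical stand-in for the real jitted computation.
    return jnp.tanh(x) * 2.0


def run_and_eval_sketch(x):
    # The call returns as soon as the work is dispatched to the device.
    out = step(x)
    # Wait for the device here, so a runtime failure in step() is raised
    # at this line instead of later inside .item() or np.asarray().
    out = jax.block_until_ready(out)
    # Assumed layout, mirroring the dict comprehension in the traceback.
    hist = {"a": {"err": out}}
    return {k: jnp.nanmedian(jnp.abs(v["err"])).item() for k, v in hist.items()}


errs = run_and_eval_sketch(jnp.arange(5.0))
```

If the exception moves to the block_until_ready line, the problem is in the jitted computation or the TPU runtime rather than in the median/item line itself. This only localizes the error; it does not explain the FAILED_PRECONDITION message.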