jit with dataclass
#11022
-
When I use ....
[ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, ...,
2.4825400e+02, 0.0000000e+00, -3.5320038e-01],
[ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, ...,
0.0000000e+00, 2.4925099e+02, 0.0000000e+00],
[ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, ...,
-3.5320038e-01, 0.0000000e+00, 2.5025000e+02]], dtype=float32), N=501, E_GS_jnp=DeviceArray(0.06263407, dtype=float32)). The error was:
TypeError: JAX DeviceArray, like numpy.ndarray, is not hashable.
At:
/home/jimmy/anaconda3/envs/qrl/lib/python3.7/site-packages/jax/interpreters/xla.py(1293): __hash__
<string>(2): __hash__
test_lmg_ham_jax.py(200): test_evolve_jnp
test_lmg_ham_jax.py(216): <module>

Any idea how to enable jit with a dataclass?
Answered by jakevdp, Jun 8, 2022
-
Answer selected by sharadmv
You can register the dataclass as a pytree if you want it to be compatible with jit and other transforms; for more information, see Extending PyTrees. If it's not clear how to apply that to your situation, you might edit the question to add a minimal reproducible example, so that it's clearer to others exactly what you're trying to do.
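A minimal sketch of what registering a dataclass as a pytree might look like, using `jax.tree_util.register_pytree_node`. The class `Params`, its fields `H`, `E_GS` and `N`, and the function `energy_shift` are hypothetical, loosely modeled on the values visible in the error message above. The assumption is that the arrays should be traced by jit, while the integer `N` is treated as static metadata (the pytree's "aux data").

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp
from jax import tree_util


@dataclass
class Params:
    # Hypothetical fields, loosely modeled on the repr in the error message.
    H: jnp.ndarray     # array leaf: traced under jit
    E_GS: jnp.ndarray  # array leaf: traced under jit
    N: int             # static metadata: stored in the treedef, must be hashable


def _flatten(p):
    # Children are the leaves jit traces; aux_data is hashed and compared.
    children = (p.H, p.E_GS)
    aux_data = p.N
    return children, aux_data


def _unflatten(aux_data, children):
    # Rebuild the dataclass from the (possibly traced) children.
    H, E_GS = children
    return Params(H=H, E_GS=E_GS, N=aux_data)


tree_util.register_pytree_node(Params, _flatten, _unflatten)


@jax.jit
def energy_shift(p: Params) -> jnp.ndarray:
    # Inside jit, p.H and p.E_GS are tracers, while p.N is a plain Python int.
    return jnp.trace(p.H) / p.N - p.E_GS


p = Params(H=jnp.eye(3) * 2.0, E_GS=jnp.array(0.5), N=3)
print(energy_shift(p))  # trace 6.0 / 3 - 0.5
```

With this registration the dataclass is passed to the jitted function as an ordinary argument rather than through static_argnums, so jit never tries to hash the arrays; only `N` is hashed, and calling with a different `N` triggers a retrace. Recent JAX releases also provide `jax.tree_util.register_dataclass` for this pattern, which may be more convenient if your JAX version has it.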