I have jax 0.5.2 and jaxlib 0.5.1 (the versions that pip installs by default). But when I try to import the library:
from neural_tangents import stax
I get the error:
AttributeError: module 'jax.random' has no attribute 'KeyArray'
It looks like this was deprecated in newer jax versions, so I installed jax==0.4.23. However, I am then unable to install the matching jaxlib version:
ERROR: Could not find a version that satisfies the requirement jaxlib==0.4.23 (from versions: 0.4.34, 0.4.35, 0.4.36, 0.4.38, 0.5.0, 0.5.1)
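For reference, here is a minimal sketch of how the installed versions and the missing attribute can be checked in the same environment. It assumes the pip distribution names are jax, jaxlib and neural-tangents. The helper names installed_versions and has_key_array are only illustrative and are not part of any library.

```python
from importlib import metadata


def installed_versions(packages=("jax", "jaxlib", "neural-tangents")):
    """Return {distribution name: version string, or None if not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # Distribution is not installed in this environment.
            versions[name] = None
    return versions


def has_key_array():
    """True/False if jax.random does/doesn't expose KeyArray; None if jax is missing."""
    try:
        import jax.random
    except ImportError:
        return None
    # hasattr returns False when the attribute lookup raises AttributeError.
    return hasattr(jax.random, "KeyArray")


print(installed_versions())
print("jax.random.KeyArray present:", has_key_array())
```

Running this in the environment where the import fails shows which jax/jaxlib pair is actually being picked up and whether jax.random.KeyArray exists there.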
How can I fix this?