-
import jax
import jax.numpy as jnp

@jax.jit
def f(Y):
    return jnp.sum(Y**2)

Y = jnp.ones(2)
for i in range(5):
    print(f._cache_size())
    f(Y)

The above prints |
-
Thanks for the question. I'm not sure what caused this change. After some investigation, it seems to be related not to the jaxlib version but to the jax version: the behavior changed sometime between jax 0.3.14 and 0.3.15. I modified your script to print the versions:

import jax
import jaxlib
import jax.numpy as jnp

print(f"jax: {jax.__version__}")
print(f"jaxlib: {jaxlib.__version__}")

@jax.jit
def f(Y):
    return jnp.sum(Y**2)

Y = jnp.ones(2)
for i in range(5):
    print(f._cache_size(), end=' ')
    f(Y)
print()

Here are the outputs of two runs:
Among changes to jax/_src/api.py in the date ran…
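
A cross-check that does not depend on `_cache_size()` (a private method whose reported values evidently differ between releases) is to count how often the function body is actually traced. This is only a minimal sketch and not something from the thread: `trace_count` is a name made up for illustration, and the approach relies on the fact that Python side effects inside a jitted function run only while JAX traces it.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical helper counter, not part of JAX


@jax.jit
def f(Y):
    # This Python-level side effect executes only during tracing,
    # so the counter grows once per new input shape/dtype.
    global trace_count
    trace_count += 1
    return jnp.sum(Y**2)


Y = jnp.ones(2)
for i in range(5):
    f(Y)  # same shape and dtype every time, so only the first call traces
print("traces after 5 identical calls:", trace_count)

f(jnp.ones(3))  # a new input shape forces a retrace
print("traces after a new shape:", trace_count)
```

If the counter stays at one across the five identical calls, the function is not being retraced, whatever `_cache_size()` reports on a given jax version.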