
Thanks for the question. I'm not sure what caused this change. After some investigation, it appears to depend on the jax version rather than the jaxlib version: the behavior changed sometime between jax 0.3.14 and 0.3.15. I modified your script to print the versions:

import jax
import jaxlib
import jax.numpy as jnp

print(f"jax: {jax.__version__}")
print(f"jaxlib: {jaxlib.__version__}")

@jax.jit
def f(Y):
    return jnp.sum(Y**2)

Y = jnp.ones(2)

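# f._cache_size() is a private, undocumented method that reports how many
# entries f's compilation cache currently holds; print it before each call.
for i in range(5):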
    print(f._cache_size(), end=' ')
    f(Y)
print()

Here are the outputs of two runs:

jax: 0.3.14
jaxlib: 0.3.14
0 1 1 1 1 
jax: 0.3.15
jaxlib: 0.3.14
0 1 2 2 2 

Among changes to jax/_src/api.py in the date ran…
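Because `_cache_size()` is private, its exact meaning may shift between releases. As a rough cross-check that does not rely on it (a sketch, not verified against 0.3.14 or 0.3.15 specifically), you can count how many times JAX actually traces `f`. The Python body of a jitted function runs only while JAX traces it, not on cache hits, so a plain Python counter incremented inside the body counts traces rather than calls. The names `trace_count` and `counts` are hypothetical helpers, not JAX APIs. If this counter stays at 1 while `_cache_size()` reports 2, the extra cache entry did not come from retracing the function.

```python
import jax
import jax.numpy as jnp

# Hypothetical helper counter (not a JAX API). It is incremented only while
# JAX traces f, because the Python body of a jitted function is not
# re-executed when a call hits the compilation cache.
trace_count = 0

@jax.jit
def f(Y):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time only
    return jnp.sum(Y**2)

Y = jnp.ones(2)

counts = []
for i in range(5):
    f(Y)
    counts.append(trace_count)  # number of traces observed so far

print(f"jax: {jax.__version__}")
print("traces after each call:", counts)
```

Running this alongside the `_cache_size()` script on the same install lets you compare the two sequences directly.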

Answer selected by pipme
Category: Q&A