For the following piece of code, I have marked `shape` as a static parameter, so why does JAX retrace the function when `shape` changes?

```python
from functools import partial

import equinox as eqx
import numpy as np
from jax import jit
from jax import numpy as jnp
from jax import random as jrd
from jax import tree as jtr
from numpyro import distributions as dist


def rand_key():
    return jrd.PRNGKey(np.random.randint(0, 10**7))


N = 100
norms = [
    dist.Normal(
        jrd.uniform(key=rand_key(), minval=-10, maxval=10),
        jrd.uniform(key=rand_key(), minval=0, maxval=2),
    )
    for _ in range(N)
]


@partial(jit, static_argnums=(1,))
@eqx.debug.assert_max_traces(max_traces=1)
def jit_sample(key, shape):
    # Draw `shape` samples from each distribution and stack them into one array.
    return jnp.array(
        jtr.map(
            lambda d: d.sample(key, shape),
            norms,
            is_leaf=lambda x: isinstance(x, dist.Normal),
        )
    )


samples = jit_sample(jrd.PRNGKey(0), (1000,))
print(samples.shape)
# A new value of the static `shape` argument triggers a second trace,
# which trips assert_max_traces(max_traces=1).
samples = jit_sample(jrd.PRNGKey(1), (100,))
print(samples.shape)
```
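In this code `key` is an ordinary traced argument while `shape` is static, and the second call changes both. The sketch below separates the two effects in plain JAX. It assumes `jax.random.normal` as a stand-in for the numpyro distributions, and `n_traces` is a hypothetical global counter that only increments while JAX traces the function. Changing only the key reuses the compiled function; changing `shape` triggers a new trace.

```python
from functools import partial

import jax

n_traces = 0  # hypothetical counter: incremented only at trace time


@partial(jax.jit, static_argnums=(1,))
def sample(key, shape):
    global n_traces
    n_traces += 1  # Python side effect: runs during tracing, not on cached calls
    return jax.random.normal(key, shape)


sample(jax.random.PRNGKey(0), (1000,))
sample(jax.random.PRNGKey(1), (1000,))  # new key value, same shape/dtype: cache hit
print(n_traces)  # 1
sample(jax.random.PRNGKey(1), (100,))   # new static `shape` value: retrace
print(n_traces)  # 2
```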
Answered by jakevdp, Jun 22, 2024
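A "static parameter" doesn't mean that the parameter's value will never change. It means that the compiled program depends on the parameter's value: JAX caches one compiled version of the function per distinct value of its static arguments. By design, if the value of a static parameter changes to one that hasn't been seen before, the function is retraced and recompiled. I hope that's clear!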
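A minimal sketch of that caching behaviour, again assuming `jax.random.normal` as a stand-in for the distributions; `trace_log` is a hypothetical list that records a shape only when JAX traces. Each distinct static value is traced once, and returning to a value that was already seen reuses the cached compilation.

```python
from functools import partial

import jax

trace_log = []  # hypothetical log: appended to only when a new trace happens


@partial(jax.jit, static_argnums=(1,))
def sample(key, shape):
    trace_log.append(shape)  # Python side effect: runs at trace time only
    return jax.random.normal(key, shape)


key = jax.random.PRNGKey(0)
for shape in [(1000,), (100,), (1000,), (100,)]:
    sample(key, shape)  # the 3rd and 4th calls hit the cache

print(trace_log)  # [(1000,), (100,)]
```

So with a small, fixed set of shapes, a trace limit such as `max_traces` would need to be at least the number of distinct shapes rather than 1. The output shape depends on `shape`, and JAX requires array shapes to be known at trace time, which is why `shape` has to be static in the first place.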
Answer selected by Qazalbash