Hi all, thanks again for developing such a fantastic library and toolkit. My lab has used it extensively in our research. We've recently developed an ultra-sparse model for PCA that performs very well and gets a substantial speedup from JIT during inference. However, there is a core computation in our post-inference analyses that we would also like to speed up, but we cannot seem to make it JIT-friendly. The issue is that the sub-vector we extract for each credible set has a data-dependent length. I would greatly appreciate any thoughts or suggestions.

```python
from functools import partial

import jax
import jax.numpy as jnp


# NOTE: this does not trace under jax.jit: the slice length below
# (min_p_gts + 1) depends on traced values, so the output shape is dynamic.
@partial(jax.jit, static_argnums=(1,))
def get_credset_v2(params, rho=0.9):
    l_dim, z_dim, p_dim = params.alpha.shape
    # sort features by descending alpha for every (ldx, zdx) pair
    idxs = jnp.argsort(-params.alpha, axis=-1)
    cs = {}
    for zdx in range(z_dim):
        cs_s = []
        for ldx in range(l_dim):
            # idxs for all features at this zdx and ldx
            p_idxs = idxs[ldx, zdx, :]
            # compute the cumulative sum over these sorted values
            p_sums = jnp.cumsum(params.alpha[ldx, zdx, p_idxs])
            # index of the first value whose cumsum is at least rho
            # (argmax of a boolean array returns the first True)
            min_p_gts = jnp.argmax(p_sums >= rho)
            # pull sub-vector as cred-set
            cs_s.append(p_idxs[: min_p_gts + 1])
        cs["z" + str(zdx)] = cs_s
    return cs
```
Hi, thanks for the question! As you've found, this sort of dynamic shape is not supported in JAX. The best you could do currently is probably to return arrays of a fixed maximum size, with only the first n elements populated.
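A minimal sketch of that padding approach, assuming `params.alpha` has shape `(l_dim, z_dim, p_dim)` and each `alpha[l, z, :]` is a probability vector. The function takes the `alpha` array directly, and the name `get_credset_padded` and the `-1` fill value are illustrative choices, not part of JAX. Instead of slicing, it keeps the full sorted index array (the largest possible size, `p_dim`), masks out everything past each credible set, and returns the per-set lengths alongside it, so every output shape depends only on `alpha.shape`:

```python
import jax
import jax.numpy as jnp


@jax.jit
def get_credset_padded(alpha, rho=0.9):
    """Padded, JIT-friendly version of the credible-set computation.

    alpha: (l_dim, z_dim, p_dim) array; each alpha[l, z, :] is assumed to be
    a probability vector. Returns
      idxs:    (l_dim, z_dim, p_dim) feature indices sorted by descending
               alpha, with every entry past the credible set replaced by -1
      lengths: (l_dim, z_dim) number of valid entries in each credible set
    """
    p_dim = alpha.shape[-1]
    # Sort every (l, z) row at once; no Python loops over l_dim or z_dim.
    idxs = jnp.argsort(-alpha, axis=-1)
    sorted_alpha = jnp.take_along_axis(alpha, idxs, axis=-1)
    p_sums = jnp.cumsum(sorted_alpha, axis=-1)
    reached = p_sums >= rho
    # argmax returns the first maximum, i.e. the first position where the
    # cumulative sum reaches rho; fall back to the full row if it never does
    # (e.g. because of floating-point round-off).
    first = jnp.where(
        reached.any(axis=-1),
        jnp.argmax(reached.astype(jnp.int32), axis=-1),
        p_dim - 1,
    )
    lengths = first + 1
    # Mask instead of slicing: keep the first `lengths` entries, pad with -1.
    mask = jnp.arange(p_dim) < lengths[..., None]
    return jnp.where(mask, idxs, -1), lengths
```

Dropping the Python loops also means the traced program no longer contains `l_dim * z_dim` unrolled copies of the per-set computation.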
In the future, this should hopefully get easier. There is some experimental work on "vmap with piles" (#13139), which would effectively allow you to represent your `cs_s` as a ragged array (i.e. a 2D array-like object where each row has a different length). You can search JAX issues and pull requests for "dynamic shapes" to see some of the other progress toward this.

But again, in the short term, I think padding outputs to a maximum length would probably be your best bet for getting this working with JAX transforms like JIT.
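Until ragged arrays are available, one option for still getting the variable-length sub-vectors (the `cs` dict of lists from the question) is to trim a padded result after the jitted call, outside of JIT, where dynamic shapes are not a problem. A sketch, assuming the padded output is an index array of shape `(l_dim, z_dim, p_dim)` with the valid entries first, plus a `lengths` array of shape `(l_dim, z_dim)`; `unpad_credsets` is a hypothetical helper name:

```python
import numpy as np


def unpad_credsets(idxs, lengths):
    """Rebuild ragged credible sets from padded arrays on the host.

    idxs:    (l_dim, z_dim, p_dim) integer array, valid entries first.
    lengths: (l_dim, z_dim) number of valid entries per set.
    Returns {"z0": [set for l=0, set for l=1, ...], "z1": [...], ...},
    mirroring the dict layout built by get_credset_v2.
    """
    # np.asarray pulls JAX arrays back to host memory; this runs outside jit.
    idxs = np.asarray(idxs)
    lengths = np.asarray(lengths)
    l_dim, z_dim, _ = idxs.shape
    return {
        f"z{zdx}": [idxs[ldx, zdx, : lengths[ldx, zdx]] for ldx in range(l_dim)]
        for zdx in range(z_dim)
    }
```

The expensive sort and cumulative sum stay inside the compiled function; only the slicing into ragged Python lists happens eagerly.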