In JAX, array shapes are static and represented by Python integers. Any operation on these static integers is done in Python at trace time. For more on JAX's computational/tracing model, you might find this doc useful: https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html
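A minimal sketch of that point, assuming a recent JAX release (the function names `num_lower_tri` and `comb_of_traced` are hypothetical, for illustration only). Inside `jax.jit`, `x.shape[0]` is a plain Python `int`, so `math.comb` runs in ordinary Python while tracing and only its result (the output size) ends up in the jaxpr. Passing a traced value to `math.comb` instead fails, because a tracer cannot be converted to a Python `int`:

```python
import math

import jax
import jax.numpy as jnp


def num_lower_tri(x):
    # x is a tracer under jit, but x.shape is a tuple of Python ints.
    n = x.shape[0]
    # math.comb runs in plain Python at trace time: comb(5, 2) == 10.
    size = math.comb(n, 2)
    # The output shape is therefore a static constant in the traced program.
    return jnp.zeros((size,), dtype=x.dtype)


x = jnp.ones((5, 5))
print(jax.jit(num_lower_tri)(x).shape)  # (10,)
# The jaxpr contains the literal shape [10]; math.comb itself never appears.
print(jax.make_jaxpr(num_lower_tri)(x))


def comb_of_traced(n):
    # Here n is a traced argument, not a Python int, so math.comb cannot use it.
    return math.comb(n, 2)


try:
    jax.jit(comb_of_traced)(5)
except TypeError as err:  # JAX's tracer-to-int conversion error subclasses TypeError
    print(type(err).__name__)
```

This is the distinction between static shape information (Python values, computed once per trace) and traced array values (abstract, only known when the compiled program runs).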
-
Partial application of |
-
I wrote a custom function to get the lower triangular values of a matrix (I am aware of `A[jnp.tril_indices(n, k=-1)]`). The function depends on a call to `math.comb`. Somehow JAX seems to trace the function correctly without this call. How does this work?

For a specific case of a (5, 5) matrix and an expected output of shape (10,), inspecting the first few lines of the JAXPR shows

How does JAX know that the shape of the flattened array should be (10,)? I know from previous testing that JAX fails to trace `math.comb`. My only guess is that during tracing JAX gets the array size from the final loop index `l` and sets the shape to `(l+1,)`. If this is what is happening, it seems like black magic, because the shape of `a` should have been defined before dropping into the loop.

JAX Version: 0.4.8
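The original function was not included in the post, so the following is only a hypothetical reconstruction (`lower_tri_values` is an illustrative name) of a function with the structure described: `a` is allocated with `math.comb(n, 2)` entries before a Python loop fills it using a running index `l`. Because `n` comes from the static shape, the size 10 is computed in Python before the loop starts; the loop index `l` plays no role in determining the shape. The Python loops are unrolled during tracing, emitting one indexed update per element:

```python
import math

import jax
import jax.numpy as jnp


def lower_tri_values(A):
    # Hypothetical reconstruction; not the poster's actual code.
    n = A.shape[0]  # static Python int, e.g. 5
    # Output shape is fixed here, before the loop, from a Python int.
    a = jnp.zeros(math.comb(n, 2), dtype=A.dtype)
    l = 0
    # Python loops are unrolled at trace time; each iteration stages
    # one `.at[l].set(...)` update with a constant index l.
    for i in range(1, n):
        for j in range(i):
            a = a.at[l].set(A[i, j])
            l += 1
    return a


A = jnp.arange(25.0).reshape(5, 5)
print(jax.jit(lower_tri_values)(A))
# Same values and ordering as the built-in approach from the question.
print(A[jnp.tril_indices(5, k=-1)])
```

For large `n` the unrolled loop produces a long jaxpr, which is one reason the `jnp.tril_indices` indexing approach is usually preferable.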