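Thanks for the question! JAX's tracing behavior here can be a bit confusing, but essentially, when JAX traces a function it evaluates any Python control flow that depends only on static quantities (such as array shapes) and flattens it away, so only the resulting array operations end up in the traced program. In the case of your convolve function, you can see this by calling jax.make_jaxpr, which returns the jaxpr for the function (shown below as it displays in an interactive session):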

from jax import make_jaxpr
make_jaxpr(convolve)(x, w)
{ lambda  ; a b.
  let c = broadcast_in_dim[ broadcast_dimensions=(  )
                            shape=(1,) ] 0
      d = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0,), collapsed_slice_dims=(), start_index_map=(0,))
                  indices_are_sorted=True
                  slice_sizes=(3,)
        …
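To see the flattening concretely, here is a minimal sketch. The thread doesn't show your original `convolve`, so the version below is a hypothetical stand-in (a naive 1D convolution written with a Python for-loop), and `num_eqns` is a hypothetical helper added for illustration. Because the loop bound comes from `len(x)`, a static shape, the loop runs in Python during tracing and each iteration emits its own operations into the jaxpr, so a longer input produces a jaxpr with more equations:

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the asker's `convolve` (not shown in this thread):
# a naive "valid" 1D convolution written with an ordinary Python for-loop.
def convolve(x, w):
    output = []
    # len(x) is a static shape, so this loop executes in Python while tracing;
    # each iteration appends its own slice + dot operations to the jaxpr.
    for i in range(1, len(x) - 1):
        output.append(jnp.dot(x[i - 1:i + 2], w))
    return jnp.array(output)

w = jnp.array([2.0, 3.0, 4.0])

def num_eqns(n):
    """Hypothetical helper: count primitive equations in the jaxpr for len(x) == n."""
    x = jnp.arange(n, dtype=jnp.float32)
    return len(jax.make_jaxpr(convolve)(x, w).jaxpr.eqns)

# Print the full jaxpr for a length-5 input; the loop is already unrolled.
print(jax.make_jaxpr(convolve)(jnp.arange(5, dtype=jnp.float32), w))

# The equation count grows with the input length, because the loop is unrolled.
print(num_eqns(5), num_eqns(10))
```

The exact primitives printed (for example, `gather` versus `slice` for the indexing) can differ between JAX versions, but in each case the loop itself never appears in the jaxpr.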

Answer selected by kamalkraj
Category: Q&A