
Hi, thanks for the question! As you've found, this sort of dynamic shape is not supported in JAX. The best you could do currently is probably to return arrays of a fixed maximum size, with only the first n elements populated.
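Here is a minimal sketch of the padding approach. It assumes a hypothetical task: keep only the positive entries of a 1-D input. The bound `MAX_SIZE` and the function name are illustrative and not from the question. Passing `size=` to `jnp.nonzero` makes the output shape static, so the function can be `jit`-compiled. The true count `n` is returned alongside the padded array.

```python
import jax
import jax.numpy as jnp

MAX_SIZE = 8  # hypothetical static upper bound on the number of outputs

@jax.jit
def positive_values_padded(x):
    mask = x > 0
    n = jnp.sum(mask)  # number of valid entries (a traced value, not a shape)
    # size= fixes the output length; unused index slots are filled with 0
    (idx,) = jnp.nonzero(mask, size=MAX_SIZE, fill_value=0)
    # Zero out the padded tail so only the first n elements carry data
    vals = jnp.where(jnp.arange(MAX_SIZE) < n, x[idx], 0.0)
    return vals, n

vals, n = positive_values_padded(jnp.array([1.0, -2.0, 3.0, -4.0, 5.0]))
```

Downstream code must then respect `n`, for example by masking with `jnp.arange(MAX_SIZE) < n`, rather than reading the whole padded array.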

In the future, this should hopefully get easier. There is some experimental work on "vmap with piles" (#13139), which would effectively allow you to represent your cs_s as a ragged array (i.e. a 2D array-like object where each row has a different length). You can search JAX issues and pull requests for "dynamic shapes" to see some of the other progress toward this.
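Until something like piles is available, you can emulate a ragged array by storing a padded 2D array plus a vector of row lengths, then `vmap` over the rows with a length mask. The sketch below uses this padded-plus-lengths technique. It is not the experimental piles API. The helper names `to_padded` and `row_sum` are hypothetical.

```python
import numpy as np
import jax
import jax.numpy as jnp

def to_padded(rows, max_len):
    """Pack a list of 1-D sequences into a (num_rows, max_len) array plus lengths."""
    data = np.zeros((len(rows), max_len), dtype=np.float32)
    lengths = np.array([len(r) for r in rows], dtype=np.int32)
    for i, r in enumerate(rows):
        data[i, : len(r)] = r  # valid data first, zero padding after
    return jnp.asarray(data), jnp.asarray(lengths)

def row_sum(row, length):
    # Only the first `length` entries of this row are real data
    valid = jnp.arange(row.shape[0]) < length
    return jnp.sum(jnp.where(valid, row, 0.0))

padded, lengths = to_padded([[1.0, 2.0, 3.0], [4.0], [5.0, 6.0]], max_len=4)
sums = jax.vmap(row_sum)(padded, lengths)  # one result per (ragged) row
```

Every row shares the same static shape `(max_len,)`, so `vmap` and `jit` work as usual. The cost is wasted computation on the padding.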

But again, in the short term, I think padding outputs to a maximum length would probably be your best b…

Answer selected by quattro