
Unfortunately, it's not possible to do what you're asking with jax.vmap, because vmap requires every mapped input to have the same size along the batched axis. We've experimented along these lines (see e.g. #16541 and related work), but nothing is ready to use yet.

I suspect your best option here would be to pad all batches to the same length, and then use vmap on the padded version.
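As a rough illustration of the padding approach, here is a minimal sketch. It assumes each batch element is a 1-D float array and that the per-element computation is a mean; `pad_and_mask` and `masked_mean` are hypothetical helpers written for this example, not JAX APIs. The boolean mask is an extra assumption added here so that the padded entries don't change the result.

```python
import jax
import jax.numpy as jnp
import numpy as np

def pad_and_mask(arrays, pad_value=0.0):
    """Hypothetical helper: pad a list of 1-D arrays to a common length.

    Returns (padded, mask), where mask is True for real entries and
    False for padding.
    """
    max_len = max(len(a) for a in arrays)
    padded = np.full((len(arrays), max_len), pad_value, dtype=np.float32)
    mask = np.zeros((len(arrays), max_len), dtype=bool)
    for i, a in enumerate(arrays):
        padded[i, : len(a)] = a
        mask[i, : len(a)] = True
    return jnp.asarray(padded), jnp.asarray(mask)

def masked_mean(x, m):
    # Written for a single padded row; padded positions are zeroed out
    # and excluded from the count via the mask.
    return jnp.sum(jnp.where(m, x, 0.0)) / jnp.sum(m)

ragged = [np.array([1.0, 2.0, 3.0]), np.array([4.0]), np.array([5.0, 6.0])]
padded, mask = pad_and_mask(ragged)

# Every row now has the same shape, so vmap can map over the leading axis.
means = jax.vmap(masked_mean)(padded, mask)  # values 2.0, 4.0, 5.5
```

Because every padded batch has a fixed shape, wrapping the vmapped function in jax.jit also avoids recompiling for each distinct input length; the cost is some wasted computation on the padded entries.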

Answer selected by ajithmoola
Category: Q&A. Labels: none. Participants: 2.