How do I scan / map / vmap over a function of variable sized inputs? #11171
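Doing this:

```python
import jax
import jax.numpy as jnp

def foo(x):
    return x.sum()

jax.lax.map(foo, [jnp.array([1, 2]), jnp.array([1, 2, 3])])
```

causes an error, because `lax.map` requires every input to share the same leading axis size. This kind of thing can come up if you have batched data of different sizes, for example.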
Answered by YouJiacheng, Jun 20, 2022
Answer selected by ahwillia
The standard way is to pad the data to the same size.
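A minimal sketch of padding for the example in the question, assuming 1-D inputs right-padded with zeros. `pad_to_same_size` is a hypothetical helper, and the boolean mask is an addition of this sketch (not part of the answer above) so that `foo` can ignore the padded slots:

```python
import jax
import jax.numpy as jnp

def pad_to_same_size(arrays):
    """Hypothetical helper: right-pad 1-D arrays with zeros to a common length.

    Returns a (batch, max_len) array and a boolean mask that is True on real entries.
    """
    lengths = jnp.array([a.shape[0] for a in arrays])
    max_len = max(a.shape[0] for a in arrays)
    padded = jnp.stack([jnp.pad(a, (0, max_len - a.shape[0])) for a in arrays])
    mask = jnp.arange(max_len)[None, :] < lengths[:, None]
    return padded, mask

def foo(x, mask):
    # Zero out padded positions before reducing, so they do not affect the sum.
    return jnp.where(mask, x, 0).sum()

xs = [jnp.array([1, 2]), jnp.array([1, 2, 3])]
padded, mask = pad_to_same_size(xs)

# Every row now has the same shape, so vmap and lax.map both work.
print(jax.vmap(foo)(padded, mask).tolist())  # [3, 6]
print(jax.lax.map(lambda args: foo(*args), (padded, mask)).tolist())  # [3, 6]
```

For a plain sum, zero padding is already harmless, but for reductions such as a mean or a max the mask is what keeps the padded entries from changing the result.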
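For better efficiency, you can bucket your data by size: to draw a batch, first sample a bucket, then sample a batch of data from within that bucket, instead of sampling a batch directly from the whole dataset. This keeps examples of similar size together, so less padding is wasted.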
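A rough sketch of the bucketing idea in NumPy, under several assumptions of this sketch: the bucket `boundaries` are a hypothetical choice, buckets are sampled in proportion to how many examples they hold, and each batch is padded only to its own longest example. `make_buckets` and `sample_batch` are hypothetical helpers, not a JAX API:

```python
import numpy as np
import jax.numpy as jnp

def make_buckets(lengths, boundaries):
    """Hypothetical helper: group example indices by length.

    Bucket b holds examples with boundaries[b-1] < length <= boundaries[b];
    lengths above the last boundary go into a final overflow bucket.
    """
    bucket_ids = np.searchsorted(boundaries, lengths, side="left")
    buckets = {}
    for idx, b in enumerate(bucket_ids):
        buckets.setdefault(int(b), []).append(idx)
    return {b: np.array(ids) for b, ids in buckets.items()}

def sample_batch(data, buckets, batch_size, rng):
    """Hypothetical helper: sample a bucket, then a batch from inside it."""
    # Step 1: pick a bucket, weighted by how many examples it holds (an assumption).
    keys = sorted(buckets)
    sizes = np.array([len(buckets[k]) for k in keys])
    bucket = keys[rng.choice(len(keys), p=sizes / sizes.sum())]
    # Step 2: sample indices only from that bucket (with replacement if it is small).
    members = buckets[bucket]
    ids = rng.choice(members, size=batch_size, replace=len(members) < batch_size)
    # Pad only to the longest example in this batch, and record which entries are real.
    batch = [data[i] for i in ids]
    lengths = np.array([len(x) for x in batch])
    max_len = lengths.max()
    padded = np.stack([np.pad(x, (0, max_len - len(x))) for x in batch])
    mask = np.arange(max_len)[None, :] < lengths[:, None]
    return jnp.asarray(padded), jnp.asarray(mask)

# Usage with made-up data: 1-D arrays of varying length.
rng = np.random.default_rng(0)
data = [np.arange(n) for n in [2, 3, 3, 7, 8, 8, 15, 16]]
buckets = make_buckets(np.array([len(x) for x in data]), boundaries=[4, 8, 16])
padded, mask = sample_batch(data, buckets, batch_size=2, rng=rng)
print(padded.shape, mask.sum(axis=1))
```

Under `jax.jit`, every new padded shape triggers a recompilation, so a variant that pads each batch to its bucket's upper boundary instead of the batch maximum would limit the number of distinct shapes to the number of buckets.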