Static argument in vmapped function? #10712
Replies: 2 comments
-
"Static" arguments in … In your case, though, it sounds like you're asking how you can map over an argument like "indices", whose elements are 1D arrays of different lengths. The best way to handle the particular example in your question would probably be a list comprehension:

result = [f(i) for i in indices]
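A minimal sketch of the list-comprehension approach follows. The asker's original "f" is not shown, so the function below (a jitted sum of squares) and the contents of "indices" are hypothetical stand-ins. Because each array has a different length, jax.jit traces and compiles one version of f per distinct shape. Repeated lengths reuse the cached compilation.

```python
import jax
import jax.numpy as jnp

# Hypothetical scalar-valued function standing in for the asker's `f`.
@jax.jit
def f(x):
    # Sum of squares of a 1D array of any length.
    return jnp.sum(x ** 2)

# A ragged collection: 1D arrays with different numbers of elements.
indices = [jnp.arange(3.0), jnp.arange(5.0), jnp.arange(2.0)]

# Apply f to each array in turn; jit compiles once per distinct length.
result = [f(i) for i in indices]
print([float(r) for r in result])  # → [5.0, 30.0, 1.0]
```

This avoids vmap entirely, so no BatchTracer is ever created and the arrays never need a common shape.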
-
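JAX dispatches computations asynchronously on the GPU, so you can simply loop over "indices": each call returns as soon as its work is queued, without waiting for the previous result to finish.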
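The sketch below illustrates the asynchronous-dispatch point, again using a hypothetical jitted function f and made-up input lengths. The Python loop only enqueues work; block_until_ready() is called once at the end to wait for all results, which matters mainly when timing the loop.

```python
import jax
import jax.numpy as jnp

# Hypothetical scalar-valued function; not the asker's actual code.
@jax.jit
def f(x):
    return jnp.sum(jnp.sin(x) ** 2)

# Ragged inputs of different lengths (illustrative only).
indices = [jnp.linspace(0.0, 1.0, n) for n in (10, 100, 1000)]

# Each call returns once the work is enqueued; on an accelerator the loop
# does not wait for the device to finish one array before starting the next.
results = [f(i) for i in indices]

# Collect the scalars into one array and wait for all computations to finish.
stacked = jnp.stack(results).block_until_ready()
print(stacked.shape)  # → (3,)
```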
-
Hello! First, thank you for all the work put into JAX. I've greatly enjoyed using it. I have a question. Is there any way to do something like the following?
This throws an error because "indices" becomes a BatchTracer, which is unhashable. I would like to be able to vmap/parallelize a function (that always returns a scalar) over a collection of 1d arrays that have different numbers of elements. I know one option is to use masking, but for my specific application, where speed/performance is important, the masking approach does not scale well.
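For comparison, here is a minimal sketch of the masking option mentioned above, assuming the same hypothetical sum-of-squares function. The ragged arrays are padded to the longest length and paired with a boolean mask, so vmap sees a single static shape. The helper pad_ragged is hypothetical, not a JAX API.

```python
import jax
import jax.numpy as jnp
import numpy as np

def pad_ragged(arrays):
    """Hypothetical helper: pad 1D arrays to a common length and build a mask."""
    max_len = max(len(a) for a in arrays)
    padded = np.zeros((len(arrays), max_len), dtype=np.float32)
    mask = np.zeros((len(arrays), max_len), dtype=bool)
    for row, a in enumerate(arrays):
        padded[row, : len(a)] = a
        mask[row, : len(a)] = True
    return jnp.asarray(padded), jnp.asarray(mask)

def f_masked(x, m):
    # Sum of squares where padded entries contribute zero.
    return jnp.sum(jnp.where(m, x ** 2, 0.0))

arrays = [np.arange(3.0), np.arange(5.0), np.arange(2.0)]
padded, mask = pad_ragged(arrays)

# One jitted, vmapped call over the whole batch; every row has the same shape.
batched_f = jax.jit(jax.vmap(f_masked))
out = batched_f(padded, mask)
print(out)  # → [ 5. 30.  1.]
```

The work done is proportional to (number of arrays) × (longest length), so when lengths vary widely most of the computation is spent on padding. That is one plausible reason masking scales poorly for this kind of input.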