Efficiently implementing nested loops with single array JAX function #10883
-
Hello, I have a JAX function that operates on a single array. For example, this function behaves like the following:
Then my function behaves like this (note that the output below is just a dummy example):
The problem arises when my input becomes a multidimensional array. For example, my input has the shape
I need to do three things: iterate over each input, store each output, and reshape the output so it is properly formatted for other ML modules. To do this, I run:
As you can imagine, my code runs very slowly and quickly runs out of RAM. I'm struggling to implement more efficient loops in JAX. Any ideas on how to make the above code more efficient and/or rewrite it in a more JAX-friendly way? Thanks! (If I find a solution on my own, I'll update this thread accordingly.)
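For reference, here is a minimal sketch of the loop-based pattern described above. The original code is not shown, so the function `f`, its output size, and the input shape `(a, b, d)` are hypothetical stand-ins. Each iteration dispatches a separate small computation from Python and builds up lists of small arrays, which is a plausible reason this style is slow.

```python
import jax.numpy as jnp

def f(x):
    # Hypothetical single-array function: maps a (d,) vector to a (2,) vector.
    return jnp.stack([x.sum(), x.max()])

def apply_with_loops(xs):
    # xs has shape (a, b, d); call f once per (i, j) pair from Python.
    outputs = []
    for i in range(xs.shape[0]):
        row = []
        for j in range(xs.shape[1]):
            row.append(f(xs[i, j]))
        outputs.append(jnp.stack(row))  # (b, 2)
    return jnp.stack(outputs)  # (a, b, 2)

xs = jnp.arange(24.0).reshape(2, 3, 4)
out = apply_with_loops(xs)
print(out.shape)  # (2, 3, 2)
```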
Replies: 2 comments 5 replies
-
There may be some way to use a single vmap call, but you can definitely do:
By default, each call to vmap adds another batch dimension on the left of the input/output shapes.
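A minimal sketch of the nested-vmap approach, assuming a hypothetical per-vector function `f` and an input of shape `(a, b, d)`. Because `jax.vmap` uses `in_axes=0` and `out_axes=0` by default, the inner call maps over the `b` axis and the outer call maps over the `a` axis, so the outer batch dimension ends up leftmost in the output.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical single-array function: maps a (d,) vector to a (2,) vector.
    return jnp.stack([x.sum(), x.max()])

# Inner vmap: (b, d) -> (b, 2). Outer vmap: (a, b, d) -> (a, b, 2).
# jax.jit compiles the whole batched computation into one call.
f_batched = jax.jit(jax.vmap(jax.vmap(f)))

xs = jnp.arange(24.0).reshape(2, 3, 4)
out = f_batched(xs)
print(out.shape)  # (2, 3, 2)
```

Add one `vmap` per extra batch dimension; no Python-level loop or intermediate list is needed.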
-
Two simple solutions: reshape the input and use a single vmap call, or nest one vmap call per batch dimension.
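A minimal sketch of the reshape-plus-single-vmap idea, again assuming a hypothetical per-vector function `f`. All leading batch dimensions are flattened into one, `f` is mapped once, and the batch shape is restored afterwards. This version handles any number of leading batch dimensions.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical single-array function: maps a (d,) vector to a (2,) vector.
    return jnp.stack([x.sum(), x.max()])

@jax.jit
def apply_flat(xs):
    # Shapes are static under jit, so reading xs.shape here is fine.
    batch_shape = xs.shape[:-1]              # e.g. (a, b)
    flat = xs.reshape(-1, xs.shape[-1])      # (a*b, d)
    out = jax.vmap(f)(flat)                  # (a*b, 2)
    # Restore the original batch dimensions: (a*b, 2) -> (a, b, 2).
    return out.reshape(batch_shape + out.shape[1:])

xs = jnp.arange(24.0).reshape(2, 3, 4)
out = apply_flat(xs)
print(out.shape)  # (2, 3, 2)
```

Both approaches compute the same result; the reshape version avoids writing one `vmap` per dimension when the number of batch dimensions varies.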