Implementing Nested Loops With Multiple Shape And Size Inputs #10928
-
Hello! For example, this function behaves like the following:
Then my function behaves like this (note that the output below is just a dummy output to give an example):
The problem arises when I start passing multidimensional arrays as my input. For example, my inputs are of the shape
I need to make sure of three things: that I iterate over each input, that I store each output, and that I reshape the output so it is properly formatted for other ML modules. To do this I run:
As you can imagine, my code runs very slowly and quickly runs out of RAM. I'm struggling to implement more efficient loops with Working off of this previous question, I tried playing around with the
but I keep running into errors of the sort:
Any ideas on how to make the above code more efficient and/or rewrite it using Thanks! (If I find a solution on my own, I'll update this thread accordingly.)
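As a point of comparison for the three steps described above (iterate over each input, store each output, reshape the result), here is a minimal sketch of one common alternative: flatten the leading batch axes, call jax.vmap once, and reshape the result back. It assumes, hypothetically, that MyFunction maps one feature vector of shape (d,) plus two shared arguments to an output of shape (k,). The names my_function, loop_version, apply_over_leading_axes and all shapes are illustrative stand-ins, not the poster's code.

```python
import jax
import jax.numpy as jnp


def my_function(x, w, b):
    # Hypothetical stand-in for MyFunction: x is (d,), w is (d, k), b is (k,).
    return jnp.tanh(x @ w + b)


def loop_version(f, inputs, w, b):
    # The slow pattern: call f once per element, store each output,
    # then stack the stored outputs back into an array.
    rows = []
    for i in range(inputs.shape[0]):
        cols = []
        for j in range(inputs.shape[1]):
            cols.append(jnp.stack([f(inputs[i, j, m], w, b) for m in range(inputs.shape[2])]))
        rows.append(jnp.stack(cols))
    return jnp.stack(rows)


def apply_over_leading_axes(f, inputs, w, b, n_batch_dims):
    # Flatten the first n_batch_dims axes into a single batch axis.
    batch_shape = inputs.shape[:n_batch_dims]
    flat = inputs.reshape((-1,) + inputs.shape[n_batch_dims:])
    # One vmap over that axis; w and b are shared (None in in_axes).
    out = jax.vmap(f, in_axes=(0, None, None))(flat, w, b)
    # Put the original batch axes back in front of the per-element output shape.
    return out.reshape(batch_shape + out.shape[1:])


inputs = jnp.arange(2 * 3 * 4 * 5, dtype=jnp.float32).reshape(2, 3, 4, 5) / 100.0
w = jnp.full((5, 6), 0.1)
b = jnp.zeros(6)

batched = apply_over_leading_axes(my_function, inputs, w, b, n_batch_dims=3)
looped = loop_version(my_function, inputs, w, b)
print(batched.shape)  # (2, 3, 4, 6)
```

The loop version dispatches one small call per element, which is a plausible source of the slowness described; the vmapped version issues a single batched computation. If peak memory rather than speed is the limit, jax.lax.map applies a function sequentially over the leading axis, which generally uses less memory than vmap at the cost of speed.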
-
I think you want:

jax.vmap(
    jax.vmap(
        jax.vmap(
            MyFunction,
            (0, None, None)
        ),
        (0, None, None)
    ),
    (0, None, None)
)(inputs, input2, input3)

This is saying that the first 3 axes of
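A runnable sketch of the nested-vmap answer above, with a hypothetical my_function standing in for MyFunction and made-up shapes: inputs has three leading axes to map over, while input2 and input3 are shared (None in in_axes), so they are passed through unbatched to every call.

```python
import jax
import jax.numpy as jnp


def my_function(x, w, b):
    # Hypothetical stand-in for MyFunction: x is (d,), w is (d, k), b is (k,).
    return jnp.tanh(x @ w + b)


# Each jax.vmap maps over one leading axis of the first argument only;
# None means "do not map this argument". Three nested vmaps therefore
# map over axes 0, 1 and 2 of `inputs`.
batched_fn = jax.vmap(
    jax.vmap(
        jax.vmap(my_function, (0, None, None)),
        (0, None, None),
    ),
    (0, None, None),
)

inputs = jax.random.normal(jax.random.PRNGKey(0), (2, 3, 4, 5))
input2 = jnp.full((5, 6), 0.1)
input3 = jnp.zeros(6)

out = batched_fn(inputs, input2, input3)
print(out.shape)  # (2, 3, 4, 6)

# Spot-check one element against a direct call on the matching slice.
print(jnp.allclose(out[1, 2, 3], my_function(inputs[1, 2, 3], input2, input3), atol=1e-6))
```

Each vmap layer strips one leading axis from inputs, so the innermost function only ever sees a single (5,) vector, and the three mapped axes reappear in front of its (6,) output.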
-
@davisyoshida's solution is correct, but I suggest using a closure (or partial) to simplify the code:

import jax

def foo(inputs, input2, input3):
    @jax.vmap
    @jax.vmap
    @jax.vmap
    def f(inputs_unmapped):
        return MyFunction(inputs_unmapped, input2, input3)
    return f(inputs)
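A sketch comparing the closure version above with a functools.partial variant and the explicit in_axes version, again using a hypothetical my_function and made-up shapes. Because input2 and input3 are captured rather than passed as arguments, the decorated f takes a single argument and vmap's default in_axes=0 applies, so no (0, None, None) bookkeeping is needed.

```python
import functools

import jax
import jax.numpy as jnp


def my_function(x, w, b):
    # Hypothetical stand-in for MyFunction: x is (d,), w is (d, k), b is (k,).
    return jnp.tanh(x @ w + b)


def foo(inputs, input2, input3):
    # Closure version: input2 and input3 come from the enclosing scope,
    # so f has one argument and vmap's default in_axes=0 applies to it.
    @jax.vmap
    @jax.vmap
    @jax.vmap
    def f(x):
        return my_function(x, input2, input3)

    return f(inputs)


def foo_partial(inputs, input2, input3):
    # partial variant: w and b are bound inside the partial object, so they
    # are not arguments of the vmapped call and are not mapped over.
    g = functools.partial(my_function, w=input2, b=input3)
    return jax.vmap(jax.vmap(jax.vmap(g)))(inputs)


def foo_nested(inputs, input2, input3):
    # Explicit version: None in in_axes marks the shared, unmapped arguments.
    ax = (0, None, None)
    return jax.vmap(jax.vmap(jax.vmap(my_function, ax), ax), ax)(inputs, input2, input3)


inputs = jax.random.normal(jax.random.PRNGKey(1), (2, 3, 4, 5))
input2 = jnp.full((5, 6), 0.1)
input3 = jnp.zeros(6)

a = foo(inputs, input2, input3)
p = foo_partial(inputs, input2, input3)
n = foo_nested(inputs, input2, input3)
print(a.shape)  # (2, 3, 4, 6)
```

All three compute the same mapping; the closure and partial forms just move the "which arguments are shared" decision out of in_axes and into ordinary Python scoping.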
-
@YouJiacheng and @davisyoshida, if I could mark both of your answers as "the answer," I would. I clearly have a lot to learn, and I'm grateful to have both of you helping me out. Thank you for all of your help!