Trying to pmap a list of PyTree #19133
Unanswered
ThatBitBoss asked this question in Q&A
Replies: 1 comment
-
Hi - thanks for the question! To do what you want here, pass a pytree of arrays instead of a list of pytrees. Here's a simple example of what I mean, using jax.vmap:

    import jax
    import jax.numpy as jnp

    def f(x):
        return x['A'] + x['B']

    # A list of pytrees: the leaves are scalars, so there is no axis to map over.
    list_of_trees = [
        {'A': 1, 'B': 2},
        {'A': 5, 'B': 10}
    ]
    jax.vmap(f)(list_of_trees)  # wrong!
    # ValueError: vmap was requested to map its argument along axis 0, which implies
    # that its rank should be at least 1, but is only 0 (its shape is ())

    # A pytree of arrays: each leaf has a leading axis that vmap can map over.
    tree_of_arrays = {'A': jnp.array([1, 5]), 'B': jnp.array([2, 10])}
    result = jax.vmap(f)(tree_of_arrays)  # correct!
    print(result)
    # [ 3 15]

Your code uses a ... Hope that helps!
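If the data already exists as a list of pytrees, one possible way to get a pytree of arrays is to stack matching leaves with `jax.tree_util.tree_map`. This is a minimal sketch that reuses the `f` and `list_of_trees` from the example above; it assumes every tree in the list has the same structure and leaf shapes, and it is not taken from the original poster's code.

```python
import jax
import jax.numpy as jnp

def f(x):
    return x['A'] + x['B']

list_of_trees = [
    {'A': 1, 'B': 2},
    {'A': 5, 'B': 10},
]

# tree_map calls the lambda once per leaf position, passing the matching
# leaf from every tree in the list. Stacking them yields one dict whose
# leaves are arrays with a leading axis of length len(list_of_trees).
tree_of_arrays = jax.tree_util.tree_map(
    lambda *leaves: jnp.stack([jnp.asarray(leaf) for leaf in leaves]),
    *list_of_trees,
)

result = jax.vmap(f)(tree_of_arrays)
print(result)
# [ 3 15]
```

The stacking step runs once outside the mapped function, so `f` itself stays unchanged.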
-
Here is my code
When I go over an array of integers, it seems to work, but the list of PyTrees throws an error.
Ultimately I want to pmap over execusteGoal_Hunger, and I am obviously missing or messing up something.
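For the pmap case, the same input layout applies: `jax.pmap` maps over the leading axis of every leaf, and that axis may not be larger than the number of available devices. The sketch below is only an illustration under those assumptions. The function `f` and the dict keys are hypothetical stand-ins, not the `execusteGoal_Hunger` function from the question, and the batch is sized with `jax.local_device_count()` so that it also runs on a single-device machine.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the function to be mapped.
def f(x):
    return x['A'] + x['B']

# pmap places one slice of the leading axis on each device, so the
# leading axis is sized to the number of local devices.
n = jax.local_device_count()
tree_of_arrays = {
    'A': jnp.arange(n),
    'B': 10 * jnp.arange(n),
}

result = jax.pmap(f)(tree_of_arrays)
print(result.shape)
```

When there are more items than devices, a common pattern is to reshape each leaf to (devices, items_per_device, ...) and apply `jax.vmap` inside the pmapped function.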