How to make vmap work with a NamedTuple as the output of a function? #6857
shailesh1729 asked this question in Q&A.
Here is my example code. It has two functions, `f` and `g`:

```python
import jax
import jax.numpy as jnp


def f(x, y):
    return x + y, x - y


x = jnp.array([1, 2, 3, 4])
y = jnp.array([11, 12, 13, 14])
vf = jax.vmap(f, in_axes=(0, 0), out_axes=(0, 0))
sums, differences = vf(x, y)
print(sums)
print(differences)

from typing import NamedTuple


class SD(NamedTuple):
    u: int = 0
    v: int = 0


def g(x, y):
    return SD(u=x + y, v=x - y)


print(g(2, 3))
vg = jax.vmap(g, in_axes=(0, 0), out_axes=(0, 0))
results = vg(x, y)  # this call raises an error
```

The last line, `vg(x, y)`, fails with an error.
I also tried a dataclass and ran into similar problems. How is the handling of a NamedTuple different from a regular tuple in `vmap`?
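To make the difference concrete, here is a small sketch (not from the thread) that compares the pytree structures JAX assigns to each kind of return value, using `jax.tree_util.tree_structure` and `jax.tree_util.tree_leaves`. `SDData` is a hypothetical dataclass twin of `SD`, added only to illustrate the dataclass case mentioned above.

```python
from dataclasses import dataclass
from typing import NamedTuple

import jax
import jax.numpy as jnp


class SD(NamedTuple):
    u: int = 0
    v: int = 0


@dataclass
class SDData:  # hypothetical dataclass analogue of SD, for illustration only
    u: int = 0
    v: int = 0


x = jnp.array([1, 2, 3, 4])
y = jnp.array([11, 12, 13, 14])

tuple_def = jax.tree_util.tree_structure((x + y, x - y))      # what f returns
named_def = jax.tree_util.tree_structure(SD(u=x + y, v=x - y))  # what g returns
axes_def = jax.tree_util.tree_structure((0, 0))               # the out_axes used

# Both outputs hold two leaves...
print(tuple_def.num_leaves, named_def.num_leaves)  # 2 2
# ...but a NamedTuple is its own node type, so (0, 0) only matches the plain tuple.
print(axes_def == tuple_def)  # True
print(axes_def == named_def)  # False

# An unregistered dataclass is not a pytree container: the whole object is one leaf.
dataclass_leaves = len(jax.tree_util.tree_leaves(SDData(u=x + y, v=x - y)))
print(dataclass_leaves)  # 1
```

Because `out_axes` is matched against the output's pytree structure, the tuple `(0, 0)` lines up with `f`'s output but not with `g`'s.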
Answered by jakevdp on May 29, 2021 (answer selected by shailesh1729):
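The `out_axes` argument should either be an integer or have the same structure as the function's output. A NamedTuple is a different pytree node type from a plain tuple, so `(0, 0)` does not match an `SD` output. So I believe both `out_axes=0` and `out_axes=SD(0, 0)` will do what you wish.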
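A minimal sketch of both suggested fixes, reusing the `SD` and `g` definitions from the question. The dataclass part is an assumption that goes beyond the answer: it registers a hypothetical `SDData` dataclass with `jax.tree_util.register_pytree_node` so that JAX treats it as a container of two arrays rather than as a single opaque leaf.

```python
from dataclasses import dataclass
from typing import NamedTuple

import jax
import jax.numpy as jnp


class SD(NamedTuple):
    u: int = 0
    v: int = 0


def g(x, y):
    return SD(u=x + y, v=x - y)


x = jnp.array([1, 2, 3, 4])
y = jnp.array([11, 12, 13, 14])

# Fix 1: an integer out_axes is applied to every leaf of the output.
res_int = jax.vmap(g, in_axes=(0, 0), out_axes=0)(x, y)

# Fix 2: out_axes mirrors the output structure exactly (an SD, not a tuple).
res_sd = jax.vmap(g, in_axes=(0, 0), out_axes=SD(0, 0))(x, y)

print(type(res_int).__name__)  # SD
print(res_int.u, res_int.v)


# Assumption beyond the answer: a dataclass works once registered as a pytree node.
@dataclass
class SDData:  # hypothetical dataclass twin of SD
    u: jax.Array
    v: jax.Array


jax.tree_util.register_pytree_node(
    SDData,
    lambda d: ((d.u, d.v), None),              # flatten: (children, aux_data)
    lambda _aux, children: SDData(*children),  # unflatten: rebuild from children
)


def h(x, y):
    return SDData(u=x + y, v=x - y)


res_dc = jax.vmap(h, in_axes=(0, 0), out_axes=0)(x, y)
print(type(res_dc).__name__)  # SDData
```

With either `out_axes` form the result comes back as an `SD` whose fields are batched arrays. Passing an integer is usually simpler; mirroring the structure is useful when different fields need different output axes.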