Hi, I was looking for a way to convert a namedtuple of stacked arrays into a list of namedtuples, one per index along the leading axis. Something like this:

```python
from collections import namedtuple

import jax
import jax.numpy as jnp

Foo = namedtuple('Foo', 'a b')
foo = Foo(a=jnp.ones((10, 5, 2)), b=jnp.ones((10, 1)))

# Something equivalent to this, where the output should be a list of length 10:
foo_list = [Foo(a=a, b=b) for a, b in zip(foo.a, foo.b)]
```

I was looking into …
Answered by jakevdp on Jul 7, 2022
Answer selected by alonfnt
I think `list(map(Foo, *foo))` should do what you want for the case where `Foo` is a simple `namedtuple` containing arrays.
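A minimal sketch of why this works: `*foo` unpacks the namedtuple into its fields, and `map` walks those arrays in lockstep along their leading axis, calling `Foo(a_i, b_i)` for each index. The `unstack` helper below is hypothetical (not from the thread); it applies the same idea to an arbitrary pytree via `jax.tree_util`, assuming the tree is non-empty and every leaf has the same size along axis 0.

```python
from collections import namedtuple

import jax
import jax.numpy as jnp

Foo = namedtuple('Foo', 'a b')
foo = Foo(a=jnp.ones((10, 5, 2)), b=jnp.ones((10, 1)))

# `*foo` expands to (foo.a, foo.b); iterating a JAX array yields slices
# along axis 0, so map() builds Foo(foo.a[i], foo.b[i]) for i in 0..9.
foo_list = list(map(Foo, *foo))


# Hypothetical helper: unstack any pytree whose leaves share a leading axis.
def unstack(tree):
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    n = leaves[0].shape[0]  # assumes at least one leaf, all with size n on axis 0
    # Rebuild one tree per index, reusing the original tree structure.
    return [
        jax.tree_util.tree_unflatten(treedef, [leaf[i] for leaf in leaves])
        for i in range(n)
    ]


foo_list_generic = unstack(foo)
```

The `map` form relies on `Foo` accepting its fields positionally, which holds for plain namedtuples; the `unstack` variant also works for dicts, nested tuples, or other registered pytrees.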