What's the fastest way to stack pytrees? #9855
Asked by NeilGirdhar in Q&A (answered)
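I was forced to convert a list of pytrees into a single pytree of stacked arrays:

```python
import jax.numpy as jnp
from jax import Array
from jax.tree_util import tree_map

def stack(*elements: Array) -> Array:
    return jnp.stack(elements)

with Timer() as timer:  # Timer and log are my own helpers (not shown)
    trajectory = tree_map(stack, *results)  # results has 5000 trees in it
log.info(f"Stacking took {timer.elapsed_str(3)}")  # Stacking took 71.2 s!!!
```

What's the fastest way to accomplish this in JAX?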
Answered by NeilGirdhar on Mar 11, 2022
Switching the `stack` helper from `jnp.stack` to `np.stack` brought the runtime down from 71.2 s to 2 s!
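A minimal sketch of that fix, assuming `np` is NumPy and that host-side NumPy arrays are acceptable for the stacked result. `stack_trees` and the example `results` list are hypothetical names for illustration. The speedup is likely because `jnp.stack` dispatches a separate JAX operation for each of the thousands of input arrays, whereas `np.stack` converts them to host arrays and concatenates in one call; that explanation is an inference, not something measured in this thread.

```python
import jax
import jax.numpy as jnp
import numpy as np

def stack(*elements):
    # np.stack converts each JAX array to a NumPy array (via __array__)
    # and concatenates them in a single host-side call.
    return np.stack(elements)

def stack_trees(trees):
    """Stack a list of pytrees with identical structure, leaf by leaf.

    Returns a pytree of NumPy arrays whose leading axis indexes the input trees.
    """
    return jax.tree_util.tree_map(stack, *trees)

# Hypothetical stand-in for `results`: a list of small pytrees of JAX arrays.
results = [
    {"x": jnp.full((3,), float(i)), "y": (jnp.asarray(i), jnp.zeros((2, 2)))}
    for i in range(1000)
]
trajectory = stack_trees(results)
print(trajectory["x"].shape)  # (1000, 3)

# Optional: if later code needs JAX arrays, convert each stacked leaf once.
trajectory_jax = jax.tree_util.tree_map(jnp.asarray, trajectory)
```

The final conversion back to JAX arrays costs one transfer per leaf rather than one operation per input element, so it should stay cheap compared with the stacking itself.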