Return pytrees without implicit copy #6237
-
Suppose I have a large pytree of JAX arrays:

from collections import namedtuple
import jax
import jax.numpy as jnp
# state contains 100 different JAX arrays
# f-string so the fields are named v0 ... v99 (a plain "v{i}" is not a valid field name)
State = namedtuple("State", [f"v{i}" for i in range(100)])
@jax.jit
def foo(state):
    new_v2 = state.v0 + state.v1
    return state._replace(v2=new_v2)
state = State(*[jnp.ones(1000) * i for i in range(100)])
new_state = foo(state)

It is my understanding that the state object returned by `foo` contains copies of all 100 arrays, including the 99 that were not modified. Is there a way to prevent this, or is this behavior just baked into how JAX handles pytrees?
Replies: 2 comments 4 replies
-
You can use I'm also not sure if the IDs not being equal even implies that there was a copy on the GPU, since they could just be two different references to the same location in memory (I'm out of my depth here).
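The doubt about object IDs can be checked directly. The sketch below compares Python object identity with the address of the underlying device buffer, using `unsafe_buffer_pointer()` on single-device arrays. The function `add_one_to_first` and the array sizes are made up for illustration. Whether the pointers match depends on the JAX version and backend, so no particular result is claimed here.

```python
import jax
import jax.numpy as jnp


@jax.jit
def add_one_to_first(a, b):
    # `a` is modified, `b` is returned untouched.
    return a + 1, b


a = jnp.ones(1000)
b = jnp.zeros(1000)
new_a, new_b = add_one_to_first(a, b)

# Python-level identity: are these the same wrapper object?
same_object = new_b is b

# Device-level identity: do both arrays point at the same buffer?
same_buffer = new_b.unsafe_buffer_pointer() == b.unsafe_buffer_pointer()

print("same Python object:", same_object)
print("same device buffer:", same_buffer)
```

If the second line prints `True` while the first prints `False`, the two arrays are distinct Python references to the same memory and no device copy took place.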
-
I think it's a reasonable feature request that JAX should forward buffers without copying them if you do not modify them. We already have logic along those lines for trivial computations, but not for computations that also do some amount of computation. (It's not really related to pytrees, just to do with how JAX lowers to XLA computations.)
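Until such forwarding exists for non-trivial computations, one possible workaround is to keep the pytree update outside the jitted function. This is a minimal sketch, not an official recommendation: only the arrays that are actually needed go through `jax.jit`, and `_replace` runs in plain Python, so the untouched fields stay the very same array objects. The helper name `compute_v2` is hypothetical.

```python
from collections import namedtuple

import jax
import jax.numpy as jnp

State = namedtuple("State", [f"v{i}" for i in range(100)])


@jax.jit
def compute_v2(v0, v1):
    # Only the inputs that are actually used enter the compiled computation.
    return v0 + v1


state = State(*[jnp.ones(1000) * i for i in range(100)])

# The namedtuple is rebuilt in Python; the other 99 fields are passed
# through by reference and never touch XLA.
new_state = state._replace(v2=compute_v2(state.v0, state.v1))

print(new_state.v5 is state.v5)
```

The trade-off is that the caller has to know which leaves a function modifies, which gets awkward when many fields change or when the update pattern depends on the data.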