How to do in-place parameter update? #12625
-
A common training pattern is to repeatedly apply an update function to a parameter array. However, each update copies the parameter, which doubles peak memory usage. Can this update be done in place?

```python
import jax
import jax.numpy as jnp

@jax.jit
def update(x):
    return x + 1

def main():
    # x takes 2GB memory
    x = jnp.zeros((2 * 1024 * 1024 * 1024 // 4), dtype=jnp.float32)
    # A training loop
    for _ in range(10):
        x = update(x)
```
Answered by tomhennigan, Oct 3, 2022
-
This is supported via "buffer donation"; change your update function to the following:

```python
def update(x):
    return x + 1

update = jax.jit(update, donate_argnums=0)
```
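For context, here is a minimal runnable sketch of the donated-buffer loop (array sizes shrunk for illustration; note that donation is honored on GPU/TPU backends, while CPU backends may ignore it and emit a warning, in which case memory is not actually reused):

```python
import jax
import jax.numpy as jnp

def update(x):
    return x + 1

# donate_argnums=0 tells XLA it may reuse the buffer backing the
# first argument for the output, avoiding a second 2GB allocation.
update = jax.jit(update, donate_argnums=0)

x = jnp.zeros(4, dtype=jnp.float32)
for _ in range(3):
    x = update(x)  # old buffer may be reused for the new x

print(x)  # x is now [3., 3., 3., 3.]
```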
Answer selected by
imoneoi
-
Thanks! I enabled buffer donation, but now my program hangs forever with 0% GPU utilization. How can I track down the problem?
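One thing worth ruling out (a sketch of a common pitfall, not a diagnosis of this specific program): once an argument is donated, the original array must not be touched again after the jitted call.

```python
import jax
import jax.numpy as jnp

update = jax.jit(lambda x: x + 1, donate_argnums=0)

x = jnp.zeros(4, dtype=jnp.float32)
y = update(x)
# On GPU/TPU, x's buffer has now been donated to produce y; reading x
# here raises an error about a deleted/donated buffer. On CPU, donation
# may be ignored (with a warning), so the access can still succeed.
# Only y should be used from this point on.
```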