Replies: 1 comment
I think the reason… Also, you should make sure to use `jax.block_until_ready` when benchmarking, since JAX dispatches work asynchronously (as in the timing code below):

```python
import jax
import jax.numpy as jnp
from jax import vmap, value_and_grad
from functools import partial

# Dummy inner model: a dot product between the inner parameters and one token.
inner_params = jnp.ones(768)
inner_model = jnp.dot


@jax.jit
@vmap  # over the outer-loop batch
def forward(tokens):
    def computation(carry, token):  # called once per token
        value, grad = value_and_grad(inner_model)(inner_params, token)
        carry += grad
        return carry, value

    total_grad, inner_outputs = jax.lax.scan(
        computation, jnp.zeros_like(inner_params), tokens
    )
    return inner_outputs, total_grad


@partial(jax.jit, static_argnums=1)
@partial(vmap, in_axes=(0, None))
def forward2(tokens, group_size):
    def computation(carry, token_group):  # called once per group of tokens
        value_and_grad_fn = partial(value_and_grad(inner_model), inner_params)
        value_group, grad_group = vmap(value_and_grad_fn)(token_group)
        carry += jnp.sum(grad_group, axis=0)
        return carry, value_group

    # group_size must divide the sequence length (1024 here) for this reshape.
    token_groups = tokens.reshape(-1, group_size, *tokens.shape[1:])
    total_grad, inner_outputs = jax.lax.scan(
        computation, jnp.zeros_like(inner_params), token_groups
    )
    return inner_outputs.reshape(-1, *inner_outputs.shape[2:]), total_grad


tokens_batch = jnp.ones((12, 1024, 768))  # (batch, sequence length, features)

# %timeit is an IPython magic, so run these timings in IPython/Jupyter.
# Each variant is called once first so compilation is excluded from the timing.
print("ungrouped")
_ = jax.block_until_ready(forward(tokens_batch))
%timeit jax.block_until_ready(forward(tokens_batch))
print()

print("group size = 1")
_ = jax.block_until_ready(forward2(tokens_batch, 1))
%timeit jax.block_until_ready(forward2(tokens_batch, 1))
print()

print("group size = 4")
_ = jax.block_until_ready(forward2(tokens_batch, 4))
%timeit jax.block_until_ready(forward2(tokens_batch, 4))
print()

print("group size = 16")
_ = jax.block_until_ready(forward2(tokens_batch, 16))
%timeit jax.block_until_ready(forward2(tokens_batch, 16))
print()

print("group size = 64")
_ = jax.block_until_ready(forward2(tokens_batch, 64))
%timeit jax.block_until_ready(forward2(tokens_batch, 64))
print()

print("group size = 256")
_ = jax.block_until_ready(forward2(tokens_batch, 256))
%timeit jax.block_until_ready(forward2(tokens_batch, 256))
print()
```
Hi!
I'm working on a meta-learning project, where I am building an outer model that works with the gradients of an inner model (so gradients of gradients).
My training code has the outer model feed tokens into an inner model and perform updates based on the gradients from that inner model, something like the sketch below.
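(A minimal sketch: `inner_model` and `inner_params` here are just a dot-product placeholder for the real inner network, and `sequential_forward` is an illustrative name rather than my actual training code.)

```python
import jax
import jax.numpy as jnp

# Placeholder inner model and parameters (the real inner model is a network).
inner_params = jnp.ones(768)
inner_model = jnp.dot

def sequential_forward(tokens):
    """Run the inner model on one token at a time, accumulating its gradient."""
    total_grad = jnp.zeros_like(inner_params)
    values = []
    for token in tokens:  # one inner forward/backward pass per token
        value, grad = jax.value_and_grad(inner_model)(inner_params, token)
        total_grad = total_grad + grad
        values.append(value)
    # The outer model then updates based on total_grad (update step omitted here).
    return jnp.stack(values), total_grad
```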
In an attempt to speed up this sequential process, I tried to feed an entire group of tokens at once into `inner_model` with a `vmap` (I can't feed all the tokens at once because that would OOM), which gives the grouped `forward2` shown in the code above.

However, I noticed some counter-intuitive behavior when trying to tune the group size: increasing it actually slows down even this simple `forward2` rather than speeding it up, and I see the same pattern on both CPU and GPU (i7-1270P laptop, A100 GPU on a shared node).

So I am wondering: how does JAX treat a scan-over-vmap operation internally? What could possibly make one large vmap slower than multiple smaller vmaps executed sequentially?
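For reference, a sketch of how one might inspect what JAX actually stages out for `forward2` (using `jax.make_jaxpr` and the jit ahead-of-time lowering API; the group size of 16 is just an example):

```python
import jax

# Print the jaxpr of the grouped function; group_size is a static argument,
# so it is passed as a plain Python int.
print(jax.make_jaxpr(forward2, static_argnums=1)(tokens_batch, 16))

# forward2 is already jit-wrapped, so its lowered StableHLO can be dumped too.
print(forward2.lower(tokens_batch, 16).as_text())
```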