Ergonomic way to extract a single iteration output from a scan #20054
It's common to extract the activations for a single hidden layer in a network, but this gets annoying when using a scan over parameters. Here's a toy example:

```python
import jax
import jax.numpy as jnp

layers = 16
dim = 128
Ws = jax.random.normal(jax.random.PRNGKey(0), (layers, dim, dim))
x = jax.random.normal(jax.random.PRNGKey(1), (dim,))

def f(carry, W):
    h = W @ carry
    return h, h

final_state, all_hidden = jax.lax.scan(f, x, Ws)
the_state_I_want = all_hidden[3]
```

I'm not 100% sure, but I think this requires XLA to instantiate the entire `all_hidden` array. Usually, returning a list of all the activations from your network is fine because users can grab whichever ones they want and rely on dead code elimination (DCE) to avoid keeping the unused ones around. To get the same memory requirement with a scan:

```python
def f2(carry, W):
    x, counter, write_out = carry
    h = W @ x
    # Keep h only on the iteration we care about (index 3).
    write_out = jnp.where(counter == 3, h, write_out)
    counter = counter + 1
    return (h, counter, write_out), None

init_carry = (x, 0, jnp.zeros_like(x))
(final_state, _, the_state_I_want), _ = jax.lax.scan(f2, init_carry, Ws)
```

When you're scanning more complex functions, it's pretty intrusive to implement something like this, since the layer implementation needs to be aware that it's being used inside a scan. Does anyone have any ideas for a cleaner way to accomplish this?
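As a sanity check, the two versions above can be compared directly. This is only a sketch: the `target` name and the `1/sqrt(dim)` weight scaling are additions made here (the scaling just keeps activations at a reasonable magnitude so the tolerance is meaningful), not part of the original example.

```python
import jax
import jax.numpy as jnp

layers, dim, target = 16, 128, 3  # `target` is a name introduced for this sketch

# Assumption of this sketch: scale weights so activations stay roughly O(1).
Ws = jax.random.normal(jax.random.PRNGKey(0), (layers, dim, dim)) / jnp.sqrt(dim)
x = jax.random.normal(jax.random.PRNGKey(1), (dim,))

def f(carry, W):
    # Stacks every hidden state as a scan output.
    h = W @ carry
    return h, h

def f2(carry, W):
    # Threads a counter and a single storage slot through the carry instead.
    x, counter, write_out = carry
    h = W @ x
    write_out = jnp.where(counter == target, h, write_out)
    return (h, counter + 1, write_out), None

_, all_hidden = jax.lax.scan(f, x, Ws)
(final_state, _, extracted), _ = jax.lax.scan(f2, (x, 0, jnp.zeros_like(x)), Ws)

# Both routes should yield the hidden state produced at iteration `target`.
matches = bool(jnp.allclose(extracted, all_hidden[target], rtol=1e-4, atol=1e-4))
print(matches)
```

Note that the stored value is the carry *after* iteration `target`, which is the same thing `all_hidden[target]` holds in the stacked version.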
Replies: 1 comment 5 replies
Here's a wrapper that does it, although I'll need to jump through some more hoops to make it work with Flax.

```python
from functools import wraps

import jax
import jax.numpy as jnp

def scan_and_extract_carry(f, init, xs, iters):
    @wraps(f)
    def inner(carry, x):
        orig_carry, iter_to_val, counter = carry
        new_carry, out = f(orig_carry, x)
        # For each requested iteration i, overwrite its stored tree with the
        # new carry when counter == i; otherwise keep the stored value.
        iter_to_val = {
            i: jax.tree_util.tree_map(
                lambda curr_val, new_val: jnp.where(counter == i, new_val, curr_val),
                curr_tree, new_carry
            )
            for i, curr_tree in iter_to_val.items()
        }
        return (new_carry, iter_to_val, counter + 1), out

    # One zero-initialized copy of the carry per requested iteration.
    storage = jax.tree_util.tree_map(jnp.zeros_like, init)
    iter_to_val = {i: storage for i in iters}
    new_init = (init, iter_to_val, 0)
    (final_carry, stored_carries, _), outputs = jax.lax.scan(inner, new_init, xs)
    return final_carry, stored_carries, outputs

def f(val, x):
    carry = val + x
    return carry, -x

final, stored, outputs = scan_and_extract_carry(
    f,
    init=0,
    xs=jnp.arange(10),
    iters=[3, 5, 7]
)
```

I think this probably does save memory, since the compiled output for a simple scan->slice combination seems to hold all the carries in memory at once (the input to the slice). Example:

```python
@jax.jit
def slice_hidden(init, xs):
    final_state, all_hidden = jax.lax.scan(lambda carry, x: (carry + x, carry), init, xs)
    return final_state, all_hidden[3]

# slice_hidden is already jitted, so it can be lowered directly.
lowered = slice_hidden.lower(0, jnp.arange(10))
compiled = lowered.compile()
print(compiled.as_text())  # optimized HLO
```
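Besides reading the HLO text, the compiled object can report buffer sizes, which gives a rough way to compare the two approaches. The sketch below assumes `Compiled.memory_analysis()` is available on your backend (it may return `None` on some), and the names `slice_version`, `carry_version`, and `report` are made up for illustration. It makes no claim about which version comes out smaller; that depends on the backend and compiler version.

```python
import jax
import jax.numpy as jnp

layers, dim, target = 16, 128, 3

def slice_version(x, Ws):
    # Stack every hidden state, then index the one we want.
    def f(carry, W):
        h = W @ carry
        return h, h
    final_state, all_hidden = jax.lax.scan(f, x, Ws)
    return final_state, all_hidden[target]

def carry_version(x, Ws):
    # Keep a single storage slot in the carry instead of stacking outputs.
    def f2(carry, W):
        h, counter, write_out = carry
        h = W @ h
        write_out = jnp.where(counter == target, h, write_out)
        return (h, counter + 1, write_out), None
    (final_state, _, kept), _ = jax.lax.scan(f2, (x, 0, jnp.zeros_like(x)), Ws)
    return final_state, kept

# Abstract shapes are enough for lowering; no real data is needed.
x_spec = jax.ShapeDtypeStruct((dim,), jnp.float32)
Ws_spec = jax.ShapeDtypeStruct((layers, dim, dim), jnp.float32)

report = {}
for name, fn in [("slice", slice_version), ("carry", carry_version)]:
    compiled = jax.jit(fn).lower(x_spec, Ws_spec).compile()
    stats = compiled.memory_analysis()  # may be None on some backends
    report[name] = None if stats is None else stats.temp_size_in_bytes
print(report)
```

With a toy model this small the difference may be negligible; the comparison becomes more informative as `layers * dim` grows.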
I see - in that case your approach is probably best. I don't think XLA will fuse the indexing with the scan (though you could check by outputting the optimized HLO).
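If the iteration index is known statically, another option that leaves the scanned function untouched is to split the scan in two and read the intermediate carry at the boundary. This is a different technique from the carry-threading wrapper, not something proposed in this thread, and `scan_with_checkpoint` is a hypothetical helper name. It traces the body twice, so it trades some compile time for not having to modify `f`.

```python
import jax
import jax.numpy as jnp

def scan_with_checkpoint(f, init, xs, stop):
    """Hypothetical helper: scan in two pieces and return the carry after step `stop`."""
    tree_map = jax.tree_util.tree_map
    head = tree_map(lambda a: a[: stop + 1], xs)   # iterations 0..stop
    tail = tree_map(lambda a: a[stop + 1 :], xs)   # remaining iterations
    mid_carry, head_out = jax.lax.scan(f, init, head)
    final_carry, tail_out = jax.lax.scan(f, mid_carry, tail)
    # Re-join the per-step outputs (a no-op when f returns None as its output).
    outputs = tree_map(lambda a, b: jnp.concatenate([a, b]), head_out, tail_out)
    return final_carry, mid_carry, outputs

layers, dim = 16, 128
# Assumption of this sketch: scale weights so activations stay roughly O(1).
Ws = jax.random.normal(jax.random.PRNGKey(0), (layers, dim, dim)) / jnp.sqrt(dim)
x = jax.random.normal(jax.random.PRNGKey(1), (dim,))

def layer(carry, W):
    # The layer does not need to know it is being scanned or inspected.
    h = W @ carry
    return h, None

final_state, hidden_3, _ = scan_with_checkpoint(layer, x, Ws, stop=3)
print(hidden_3.shape)
```

This only works for a fixed number of extraction points chosen at trace time; for a dynamic index, threading a counter through the carry as in the wrapper above is still needed.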