Efficient per example training with combination of shared and per example params #6093
Unanswered
adam-hartshorne asked this question in Q&A
Replies: 1 comment 4 replies
-
Would moving the gradient inside the vmap, then accumulating the shared grad manually work? i.e.:

```python
import numpy as np
import jax
import jax.numpy as jnp
import optax


def cost_func(shared_p, example_p, example):
    return jnp.inner(shared_p, example) + jnp.inner(example_p, example)


def step(params, optimizer, optimizer_state, data):
    # Per-example value and gradients: the shared parameter is broadcast
    # (in_axes=None), the per-example parameter and data are mapped over axis 0.
    value, (example_shared_grads, per_example_grads) = jax.vmap(
        jax.value_and_grad(cost_func, argnums=(0, 1)),
        (None, 0, 0)
    )(params['shared_param'], params['per_example_param'], data)
    # Accumulate the shared parameter's gradient over all examples.
    shared_grad = jnp.sum(example_shared_grads, axis=0)
    grads = {'shared_param': shared_grad, 'per_example_param': per_example_grads}
    updates, opt_state = optimizer.update(grads, optimizer_state, params)
    return value, optax.apply_updates(params, updates), opt_state


def main():
    params = {
        'shared_param': jnp.ones(10),
        'per_example_param': jnp.ones((123, 10)),
    }
    optimizer = optax.adam(0.1)
    opt_state = optimizer.init(params)
    data = jnp.array(np.random.randn(123, 10))
    step(params, optimizer, opt_state, data)


if __name__ == '__main__':
    main()
```
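A note on why the manual accumulation is valid: the batch loss is a sum of independent per-example losses, so its gradient with respect to the shared parameter is the sum of the per-example gradients returned by vmap, while each row of per_example_param only ever receives the gradient from its own example. For actual training you would probably also jit the step; below is a minimal sketch (not from the original reply), assuming the `step`, `optimizer`, `opt_state`, `params` and `data` defined above are in scope, and closing over the optimizer rather than passing it through jit, since it is a container of Python functions rather than a pytree of arrays:

```python
import functools
import jax

# Bind the optax optimizer via functools.partial so the jitted callable
# only receives pytrees of arrays as traced arguments.
jitted_step = jax.jit(functools.partial(step, optimizer=optimizer))

# Remaining arguments are passed by keyword so they do not collide with
# the already-bound `optimizer` slot.
value, params, opt_state = jitted_step(
    params, optimizer_state=opt_state, data=data)
```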
-
If I wish to optimize a model on a per-example basis, what is the best way to handle a combination of shared parameters and parameters that are optimized separately for each example in the dataset?
Currently, I pass all of the data at once, together with the parameter dict, into my cost function, extract the parameters from the dict, and then use vmap to call an inner function that computes the per-example cost from each example, its per-example parameter, and the shared parameters. This is highly memory-inefficient, because all of the data goes through the cost function at once and gradients for the entire dataset are computed together, rather than on a per-example basis.
e.g. a minimal example of the sort of poorly devised structure I currently have (a sketch of the corresponding cost function follows the snippet):
```python
# N = number of examples in the dataset
# M = number of data points in each example
data = jnp.array(np.random.random((N, M)))
params = {'shared_param': jnp.ones(M),
          'per_example_param': jnp.ones((N, M))}
```
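For concreteness, here is a minimal sketch of the full-batch structure just described; the inner cost and the names `inner_cost` / `full_batch_cost` are illustrative placeholders, not the actual model:

```python
import numpy as np
import jax
import jax.numpy as jnp

N, M = 123, 10  # placeholder sizes for the sketch


def inner_cost(shared_p, example_p, example):
    # Illustrative per-example cost combining shared and per-example params.
    return jnp.inner(shared_p, example) + jnp.inner(example_p, example)


def full_batch_cost(params, data):
    # vmap the inner cost over every example and its per-example parameter,
    # broadcasting the shared parameter (in_axes=None).
    per_example_losses = jax.vmap(inner_cost, in_axes=(None, 0, 0))(
        params['shared_param'], params['per_example_param'], data)
    return jnp.sum(per_example_losses)


data = jnp.array(np.random.random((N, M)))
params = {'shared_param': jnp.ones(M),
          'per_example_param': jnp.ones((N, M))}

# Differentiating the summed loss materialises gradients for the whole
# dataset in one backward pass, which is where the memory cost comes from.
loss, grads = jax.value_and_grad(full_batch_cost)(params, data)
```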
Any help would be much appreciated on how I should set this up so that a) I can call the cost function on a per-example basis, and b) the per_example_param is correctly trained.