Replies: 1 comment
-
Hi, I believe I have accomplished a similar thing using:

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu

def train_step(state, batch):
    ...

def initialization(model, learning_rate, input_size, seed, weight_decay):
    # keep this function as written
    ...

# Create the branch functions -- assume we have lists of learning rates, seeds, and weight decays.
initializers = [
    jtu.Partial(initialization, model=model, learning_rate=lr, input_size=input_size, seed=seed, weight_decay=wd)
    for lr, seed, wd in zip(learning_rates, seeds, weight_decays)
]

# vmap over the initialisers: each index selects one branch via lax.switch.
states = jax.vmap(jtu.Partial(jax.lax.switch, branches=initializers))(jnp.arange(len(learning_rates)))

# Vectorise the stepping function over the first dimension of the states.
fn_step = jax.vmap(train_step, in_axes=(0, None))
```

This is quite a crude example, but I think you can play around with it to achieve what you are asking for. Note: I did not test this code, but I think the idea is clear.
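For reference, here is a minimal, self-contained sketch of the same pattern using toy stand-in functions (every name below is hypothetical, not from the original post): each `jtu.Partial` branch bakes in one hyperparameter combination, `jax.lax.switch` picks a branch by index, and `jax.vmap` over the index stacks the resulting states so a single step function can then be mapped over them.

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu

learning_rates = [1e-3, 1e-2, 1e-1]
weight_decays = [0.0, 1e-4, 1e-2]
seeds = [0, 1, 2]

def toy_init(learning_rate, weight_decay, seed):
    # Stand-in for a real initializer: the "state" is just a dict of arrays.
    w = jax.random.normal(jax.random.PRNGKey(seed), (4,))
    return {"w": w, "lr": jnp.asarray(learning_rate), "wd": jnp.asarray(weight_decay)}

def toy_step(state, batch):
    # Stand-in for a real train step: one SGD-like update using the per-state hyperparameters.
    grad = state["w"] - batch.mean()  # pretend gradient
    new_w = state["w"] - state["lr"] * (grad + state["wd"] * state["w"])
    return {**state, "w": new_w}

# One branch per hyperparameter combination, with the values baked in.
branches = [
    jtu.Partial(toy_init, learning_rate=lr, weight_decay=wd, seed=s)
    for lr, wd, s in zip(learning_rates, weight_decays, seeds)
]

# Build all per-configuration states at once; every leaf gains a leading axis of size 3.
states = jax.vmap(jtu.Partial(jax.lax.switch, branches=branches))(jnp.arange(len(branches)))

# Run the same batch through every configuration in parallel.
batch = jnp.ones((8,))
states = jax.vmap(toy_step, in_axes=(0, None))(states, batch)
print(states["w"].shape)  # (3, 4)
```

Note that `lax.switch` with a vmapped index traces every branch and selects among their outputs, so all branches must return states with identical shapes and dtypes.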
-
Hi,
I made a script to perform a grid search over hyperparameters using vmap. I want to pass parameters such as the seed, weight decay, and learning rate to train a model with optax and flax. This works fine for variables like the seed that are not changed inside vmap. However, the learning rate is modified inside the optax optimizer, which results in a side effect.
How can you pass variables to vmap that are changed during execution? Is this even possible?
My code is as follows:
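(The original snippet is not included in this excerpt. As a stand-in, here is a hypothetical minimal sketch of the kind of setup being described; all module, function, and variable names are assumed, not the author's. The point it illustrates is that the learning rate and weight decay are consumed inside the optax optimizer at state-construction time, rather than being passed through as plain array arguments, which is what makes naively vmapping over them awkward.)

```python
# Hypothetical stand-in for the setup described above, not the original poster's code.
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from flax.training import train_state

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.Dense(1)(nn.relu(nn.Dense(16)(x)))

input_size = 8
model = MLP()

def initialization(learning_rate, weight_decay, seed):
    # The learning rate and weight decay are baked into the optimizer here,
    # inside the function, rather than flowing through as array inputs.
    variables = model.init(jax.random.PRNGKey(seed), jnp.ones((1, input_size)))
    tx = optax.adamw(learning_rate=learning_rate, weight_decay=weight_decay)
    return train_state.TrainState.create(
        apply_fn=model.apply, params=variables["params"], tx=tx
    )

def train_step(state, batch):
    x, y = batch
    def loss_fn(params):
        pred = state.apply_fn({"params": params}, x)
        return jnp.mean((pred - y) ** 2)
    grads = jax.grad(loss_fn)(state.params)
    return state.apply_gradients(grads=grads)

# Single-configuration usage for illustration.
state = initialization(learning_rate=1e-3, weight_decay=1e-4, seed=0)
state = train_step(state, (jnp.ones((32, input_size)), jnp.zeros((32, 1))))
```

Following the suggestion in the reply above, per-configuration states could then be built from `jtu.Partial(initialization, learning_rate=lr, weight_decay=wd, seed=s)` branches combined with `jax.lax.switch`.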