How can I set some modules to stop learning when training in JAX? #10178
-
In PyTorch we can stop a specific module from learning by setting requires_grad=False on its parameters. How can I do the same in JAX? The module is defined like this: class XXX(nn.Module): Thanks
Replies: 4 comments 4 replies
-
Option 1

model_output = model.apply(params, model_input)
detached_model_output = jax.lax.stop_gradient(model_output)

Option 2

detached_params = jax.lax.stop_gradient(params)
model_output = model.apply(detached_params, model_input)

Note that params can be a pytree, i.e. a nested dict/tuple/list or a custom pytree type. This is suitable for flax and haiku: init will give you the params as a pytree, and apply will accept a pytree of params.

Option 3

Use https://github.com/patrick-kidger/equinox, then your model itself will be a pytree.
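Since params is a pytree, Option 2 can also be applied to just one sub-tree to freeze a single module while the rest keeps training. A minimal sketch, assuming Flax; the 'encoder'/'head' module names and the toy shapes are illustrative, not from this thread:

import jax
import jax.numpy as jnp
import flax.linen as nn

class Model(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(8, name='encoder')(x)  # sub-module we want to freeze
        return nn.Dense(1, name='head')(x)  # sub-module that keeps training

model = Model()
x = jnp.ones((4, 3))
y = jnp.ones((4, 1))
params = model.init(jax.random.PRNGKey(0), x)

def loss_fn(params, x, y):
    # Detach only the frozen sub-tree; gradients still flow to 'head'.
    frozen_encoder = jax.lax.stop_gradient(params['params']['encoder'])
    patched = {'params': {**params['params'], 'encoder': frozen_encoder}}
    pred = model.apply(patched, x)
    return jnp.mean((pred - y) ** 2)

grads = jax.grad(loss_fn)(params, x, y)
# grads['params']['encoder'] comes back as all zeros; grads['params']['head'] does not.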
-
And you may refer to
-
Use optax.set_to_zero together with optax.multi_transform.

params = {
    'a': {'x1': ..., 'x2': ...},
    'b': {'x1': ..., 'x2': ...},
}
param_labels = {
    'a': {'x1': 'freeze', 'x2': 'train'},
    'b': 'train',
}
optimizer_scheme = {
    'train': optax.adam(...),
    'freeze': optax.set_to_zero(),
}
optimizer = optax.multi_transform(optimizer_scheme, param_labels)

See the Freeze Parameters Example for details. (Taken from ayaka14732/tpu-starter, the 'Freeze certain model parameters' section.)
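As an end-to-end sketch of this recipe (the toy parameter values, learning rate, and quadratic loss are placeholders, not part of the original example):

import jax
import jax.numpy as jnp
import optax

params = {
    'a': {'x1': jnp.ones(3), 'x2': jnp.ones(3)},
    'b': {'x1': jnp.ones(3), 'x2': jnp.ones(3)},
}
param_labels = {
    'a': {'x1': 'freeze', 'x2': 'train'},
    'b': 'train',
}
optimizer = optax.multi_transform(
    {'train': optax.adam(1e-3), 'freeze': optax.set_to_zero()},
    param_labels,
)
opt_state = optimizer.init(params)

def loss_fn(params):
    # Toy quadratic loss over every leaf, just to produce nonzero gradients.
    return sum(jnp.sum(leaf ** 2) for leaf in jax.tree_util.tree_leaves(params))

grads = jax.grad(loss_fn)(params)
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
# After the update, params['a']['x1'] is unchanged; all 'train'-labelled leaves have moved.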
-
If you follow @YouJiacheng's "Option 3" above and use Equinox as your neural network library (instead of Flax), then the Equinox docs already include an example for handling frozen layers here.
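For reference, a minimal sketch of that frozen-layer pattern, assuming a recent Equinox version; the MLP shape and the choice to train only the final layer are illustrative:

import equinox as eqx
import jax
import jax.numpy as jnp

model = eqx.nn.MLP(in_size=2, out_size=1, width_size=8, depth=2, key=jax.random.PRNGKey(0))

# Mark every leaf as frozen, then flag only the last layer's weight and bias as trainable.
filter_spec = jax.tree_util.tree_map(lambda _: False, model)
filter_spec = eqx.tree_at(
    lambda m: (m.layers[-1].weight, m.layers[-1].bias),
    filter_spec,
    replace=(True, True),
)

@eqx.filter_grad
def loss_fn(diff_model, static_model, x, y):
    # Recombine the trainable and frozen halves before calling the model.
    model = eqx.combine(diff_model, static_model)
    pred = jax.vmap(model)(x)
    return jnp.mean((pred - y) ** 2)

x = jnp.ones((16, 2))
y = jnp.ones((16, 1))
diff_model, static_model = eqx.partition(model, filter_spec)
grads = loss_fn(diff_model, static_model, x, y)
# grads only carries values for the leaves marked trainable in filter_spec.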