
Commit cd777fb

gabrieloks, pre-commit-ci[bot], ssmmnn11, HCookie and anaprietonem authored
feat(training): Refactor optimizer creation to support custom and torch optimizers (#588)
## Description

Refactoring of the optimizer creation to support custom and torch optimizers.

## What problem does this change solve?

This PR enables users to select different PyTorch optimizers and also allows the use of custom ones, such as AdEMAMix.

***As a contributor to the Anemoi framework, please ensure that your changes include unit tests, updates to any affected dependencies and documentation, and have been tested in a parallel setting (i.e., with multiple GPUs). As a reviewer, you are also responsible for verifying these aspects and requesting changes if they are not adequately addressed. For guidelines about those please refer to https://anemoi.readthedocs.io/en/latest/***

By opening this pull request, I affirm that all authors agree to the [Contributor License Agreement](https://github.com/ecmwf/codex/blob/main/Legal/contributor_license_agreement.md).

📚 Documentation previews 📚:

- https://anemoi-training--588.org.readthedocs.build/en/588/
- https://anemoi-graphs--588.org.readthedocs.build/en/588/
- https://anemoi-models--588.org.readthedocs.build/en/588/

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Simon Lang <[email protected]>
Co-authored-by: Harrison Cook <[email protected]>
Co-authored-by: Ana Prieto Nemesio <[email protected]>
1 parent ca6f732 commit cd777fb

File tree · 15 files changed: +657 −54 lines


LICENCES/APPLE_ML_ACKNOWLEDGEMENTS

Lines changed: 116 additions & 0 deletions
@@ -0,0 +1,116 @@
Acknowledgements
Portions of our AdEMAMix implementation may utilize the following copyrighted
material, the use of which is hereby acknowledged.

_____________________

The Pytorch team (Pytorch)

From PyTorch:

Copyright (c) 2016- Facebook, Inc (Adam Paszke)
Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
Copyright (c) 2011-2013 NYU (Clement Farabet)
Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)

From Caffe2:

Copyright (c) 2016-present, Facebook Inc. All rights reserved.

All contributions by Facebook:
Copyright (c) 2016 Facebook Inc.

All contributions by Google:
Copyright (c) 2015 Google Inc.
All rights reserved.

All contributions by Yangqing Jia:
Copyright (c) 2015 Yangqing Jia
All rights reserved.

All contributions by Kakao Brain:
Copyright 2019-2020 Kakao Brain

All contributions by Cruise LLC:
Copyright (c) 2022 Cruise LLC.
All rights reserved.

All contributions by Arm:
Copyright (c) 2021, 2023-2024 Arm Limited and/or its affiliates

All contributions from Caffe:
Copyright(c) 2013, 2014, 2015, the respective contributors
All rights reserved.

All other contributions:
Copyright(c) 2015, 2016 the respective contributors
All rights reserved.

Caffe2 uses a copyright model similar to Caffe: each contributor holds
copyright over their contributions to Caffe2. The project versioning records
all such contribution and copyright details. If a contributor wants to further
mark their specific copyright on a particular contribution, they should
indicate their copyright solely in the commit message of the change when it is
committed.

All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
   and IDIAP Research Institute nor the names of its contributors may be
   used to endorse or promote products derived from this software without
   specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.


Google Deepmind (Jax)
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


Google Deepmind (Optax)
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

LICENCES/APPLE_ML_ADEMAMIX_LICENSE

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
MIT License

Copyright © 2024 Apple Inc.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


-------------------------------------------------------------------------------
SOFTWARE DISTRIBUTED WITH ADEMAMIX:

The AdEMAMix provided code includes a number of subcomponents with separate
copyright notices and license terms - please see the file ACKNOWLEDGEMENTS.
-------------------------------------------------------------------------------

NOTICE.md

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
This product includes third-party software components developed by Apple Inc.
Specifically, it incorporates the "AdEMAMix" optimizer implementation,
which is made available under the MIT License.
Lines changed: 102 additions & 0 deletions
@@ -0,0 +1,102 @@
############
 Optimizers
############

Optimizers are responsible for updating the model parameters during
training based on the gradients computed from the loss function. In
``anemoi-training``, optimizers are configured in the training
configuration file under ``config.training.optimizer``. By default,
optimizers are instantiated using a Hydra-style ``_target_`` entry,
allowing full flexibility to specify both standard PyTorch optimizers
and custom implementations.

The optimizer configuration is handled internally by the
``BaseGraphModule`` class through its ``_create_optimizer_from_config``
method, which reads the provided configuration and creates the
corresponding optimizer object. Additional settings, such as learning
rate schedulers and warm-up phases, are also defined and managed within
the same module.
**************************
 Configuring an Optimizer
**************************

An optimizer can be defined in the training configuration file using its
Python import path as the ``_target_``. For example, to use the standard
Adam optimizer:

.. code:: yaml

   optimizer:
     _target_: torch.optim.Adam
     betas: [0.9, 0.95]
     weight_decay: 0.1

The ``BaseGraphModule`` automatically injects the learning rate from
``config.training.lr``, so the optimizer configuration can focus on
algorithm-specific parameters.

**************************
 Learning Rate Schedulers
**************************

Learning rate schedulers can be attached to any optimizer to control the
evolution of the learning rate during training. By default,
``anemoi-training`` uses a ``CosineLRScheduler`` from ``timm.scheduler``
with optional warm-up steps and a minimum learning rate.

The scheduler is created by ``BaseGraphModule._create_scheduler`` and
returned to the trainer together with the optimizer in
``configure_optimizers``. The scheduler is currently hard-coded to
``CosineLRScheduler``; making it configurable is planned.

The scheduler is returned in a dictionary of the form:

.. code:: python

   {
       "optimizer": optimizer,
       "lr_scheduler": {
           "scheduler": scheduler,
           "interval": "step",
       },
   }
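As a rough illustration of how the pieces fit together, the helper below builds such a dictionary around a ``timm`` ``CosineLRScheduler``. The function name and the specific arguments shown (``t_initial``, ``warmup_t``, ``lr_min``) are illustrative placeholders, not the exact values or code used by ``anemoi-training``.

.. code:: python

   import torch
   from timm.scheduler import CosineLRScheduler


   def build_scheduler_dict(
       optimizer: torch.optim.Optimizer,
       total_steps: int,
       warmup_steps: int,
       lr_min: float,
   ) -> dict:
       """Illustrative sketch of the dictionary returned by ``configure_optimizers``."""
       scheduler = CosineLRScheduler(
           optimizer,
           t_initial=total_steps,   # length of the cosine schedule, in steps
           warmup_t=warmup_steps,   # linear warm-up steps
           lr_min=lr_min,           # floor for the learning rate
       )
       return {
           "optimizer": optimizer,
           "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
       }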
********************
 AdEMAMix Optimizer
********************

``AdEMAMix`` is a custom optimizer implemented in the
``anemoi.training.optimizers.AdEMAMix`` module and taken from the `Apple ML
AdEMAMix project <https://github.com/apple/ml-ademamix>`_. It combines
elements of Adam and exponential moving average (EMA) mixing for
improved stability and generalization.

The optimizer maintains **three exponential moving averages (EMAs)**: a
fast and a slow EMA of the gradients, plus an EMA of the squared
gradients. See the `AdEMAMix paper <https://arxiv.org/abs/2409.03137>`_
for more details.
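For reference, the update rule from the paper can be sketched as a single-tensor step as follows. This is a schematic reading of the algorithm, not the packaged implementation (which is documented in the reference section below), and it omits the warm-up schedules applied to ``alpha`` and ``beta3``.

.. code:: python

   import torch


   def ademamix_step_sketch(
       param: torch.Tensor,
       grad: torch.Tensor,
       m1: torch.Tensor,  # fast EMA of gradients (Adam-like first moment)
       m2: torch.Tensor,  # slow EMA of gradients (the extra AdEMAMix moment)
       nu: torch.Tensor,  # EMA of squared gradients (second moment)
       step: int,
       lr: float = 1e-3,
       betas: tuple[float, float, float] = (0.9, 0.999, 0.9999),
       alpha: float = 2.0,
       weight_decay: float = 0.0,
       eps: float = 1e-8,
   ) -> None:
       """Schematic AdEMAMix update for one parameter tensor."""
       beta1, beta2, beta3 = betas
       m1.mul_(beta1).add_(grad, alpha=1 - beta1)            # fast EMA
       m2.mul_(beta3).add_(grad, alpha=1 - beta3)            # slow EMA
       nu.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)  # squared-gradient EMA

       m1_hat = m1 / (1 - beta1**step)  # bias correction (m2 is left uncorrected)
       nu_hat = nu / (1 - beta2**step)

       update = (m1_hat + alpha * m2) / (nu_hat.sqrt() + eps)
       param.add_(-lr * (update + weight_decay * param))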
79+
***********************
80+
Configuration in YAML
81+
***********************
82+
83+
An example configuration for using ``AdEMAMix`` is shown below:
84+
85+
.. code:: yaml
86+
87+
optimizer:
88+
_target_: anemoi.training.optimizers.AdEMAMix.AdEMAMix
89+
betas: [0.9, 0.999, 0.9999]
90+
alpha: 2.0
91+
weight_decay: 0.01
92+
beta3_warmup: 1000
93+
alpha_warmup: 1000
94+
95+
**************************
96+
Implementation Reference
97+
**************************
98+
99+
.. automodule:: anemoi.training.optimizers.AdEMAMix
100+
:members:
101+
:no-undoc-members:
102+
:show-inheritance:
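Outside the Hydra configuration path, the optimizer can in principle also be constructed directly, like any other ``torch.optim`` optimizer. The snippet below is a hypothetical standalone usage whose keyword arguments mirror the YAML fields above; the exact constructor signature should be checked against the module documented in the reference section.

.. code:: python

   import torch
   from anemoi.training.optimizers.AdEMAMix import AdEMAMix

   model = torch.nn.Linear(8, 1)

   # Keyword arguments mirror the YAML example; defaults may differ.
   optimizer = AdEMAMix(
       model.parameters(),
       lr=1e-3,
       betas=(0.9, 0.999, 0.9999),
       alpha=2.0,
       weight_decay=0.01,
       beta3_warmup=1000,
       alpha_warmup=1000,
   )

   loss = model(torch.randn(4, 8)).pow(2).mean()
   loss.backward()
   optimizer.step()
   optimizer.zero_grad()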

training/src/anemoi/training/config/training/default.yaml

Lines changed: 33 additions & 4 deletions
@@ -37,11 +37,40 @@ swa:
  enabled: False
  lr: 1.e-4

-# Optimizer settings
+# =====================================================================
+# Optimizer configuration
+# =====================================================================
 optimizer:
-  zero: False # use ZeroRedundancyOptimizer ; saves memory for larger models
-  kwargs:
-    betas: [0.9, 0.95]
+  # ---------------------------------------------------------------
+  # Choose optimizer type (_target_ approach)
+  # ---------------------------------------------------------------
+  # Default optimizer: AdamW
+  _target_: torch.optim.AdamW
+
+  # ---------------------------------------------------------------
+  # Common optimizer parameters
+  # ---------------------------------------------------------------
+  # Learning rate is defined elsewhere
+  #kwargs:
+  betas: [0.9, 0.95] # β₁, β₂ for Adam-style optimizers
+
+  # ---------------------------------------------------------------
+  # Optional: configuration for AdEMAMix (custom optimizer)
+  # Uncomment the lines below to enable it
+  # ---------------------------------------------------------------
+  # _target_: anemoi.training.optimizers.AdEMAMix.AdEMAMix # Custom optimizer
+  # betas: [0.9, 0.95, 0.9999] # β₁, β₂, β₃
+  # alpha: 8.0 # Mixing factor controlling EMA fusion
+  # beta3_warmup: 260000 # Warm-up steps for β₃ (in iterations)
+  # alpha_warmup: 260000 # Warm-up steps for α (in iterations)
+  # weight_decay: 0.01
+
+  # Optional: configuration for ZeroRedundancyOptimizer
+  # _target_: torch.distributed.optim.ZeroRedundancyOptimizer
+  # optimizer_class:
+  #   _target_: torch.optim.AdamW
+  #   _partial_: true
+  # betas: [0.9, 0.95]

 # select model
 model_task: anemoi.training.train.tasks.GraphForecaster
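The commented-out ``ZeroRedundancyOptimizer`` block above relies on Hydra's nested instantiation: ``_partial_: true`` turns the inner entry into a ``functools.partial`` of ``torch.optim.AdamW``, which the wrapper then receives as ``optimizer_class``. A minimal illustration of that mechanism (standalone, not the anemoi code path) is sketched below.

.. code:: python

   import functools

   from hydra.utils import instantiate
   from omegaconf import OmegaConf

   # The nested entry from the commented-out block above.
   inner_cfg = OmegaConf.create({"_target_": "torch.optim.AdamW", "_partial_": True})

   optimizer_class = instantiate(inner_cfg)
   assert isinstance(optimizer_class, functools.partial)  # partial(torch.optim.AdamW)

   # With an initialized torch.distributed process group, the full optimizer
   # config would then resolve to (roughly):
   #   ZeroRedundancyOptimizer(model.parameters(), optimizer_class=optimizer_class,
   #                           lr=..., betas=[0.9, 0.95])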

training/src/anemoi/training/config/training/diffusion.yaml

Lines changed: 4 additions & 5 deletions
@@ -39,11 +39,10 @@ swa:

 # Optimizer settings
 optimizer:
-  zero: False
-  kwargs:
-    weight_decay: 0.1
-    betas: [0.9, 0.95]
-    eps: 1e-7
+  _target_: torch.optim.AdamW
+  weight_decay: 0.1
+  betas: [0.9, 0.95]
+  eps: 1e-7

 # select model
 model_task: anemoi.training.train.tasks.GraphDiffusionForecaster

training/src/anemoi/training/config/training/ensemble.yaml

Lines changed: 2 additions & 3 deletions
@@ -39,9 +39,8 @@ swa:

 # Optimizer settings
 optimizer:
-  zero: False # use ZeroRedundancyOptimizer ; saves memory for larger models
-  kwargs:
-    betas: [0.9, 0.95]
+  _target_: torch.optim.AdamW
+  betas: [0.9, 0.95]

 # select model
 model_task: anemoi.training.train.tasks.GraphEnsForecaster

training/src/anemoi/training/config/training/interpolator.yaml

Lines changed: 2 additions & 3 deletions
@@ -39,9 +39,8 @@ swa:

 # Optimizer settings
 optimizer:
-  zero: False # use ZeroRedundancyOptimizer ; saves memory for larger models
-  kwargs:
-    betas: [0.9, 0.95]
+  _target_: torch.optim.AdamW
+  betas: [0.9, 0.95]

 # select model
 model_task: anemoi.training.train.tasks.GraphInterpolator

training/src/anemoi/training/config/training/lam.yaml

Lines changed: 2 additions & 3 deletions
@@ -39,9 +39,8 @@ swa:

 # Optimizer settings
 optimizer:
-  zero: False # use ZeroRedundancyOptimizer ; saves memory for larger models
-  kwargs:
-    betas: [0.9, 0.95]
+  _target_: torch.optim.AdamW
+  betas: [0.9, 0.95]

 # select model
 model_task: anemoi.training.train.tasks.GraphForecaster

training/src/anemoi/training/config/training/stretched.yaml

Lines changed: 2 additions & 3 deletions
@@ -39,9 +39,8 @@ swa:

 # Optimizer settings
 optimizer:
-  zero: False # use ZeroRedundancyOptimizer ; saves memory for larger models
-  kwargs:
-    betas: [0.9, 0.95]
+  _target_: torch.optim.AdamW
+  betas: [0.9, 0.95]

 # select model
 model_task: anemoi.training.train.tasks.GraphForecaster
