Commit 44cd7a0

[Hackathon 8th No.12] Implement the SOAP optimizer in PaddleScience (#1102)
* Add SOAP optimizer
* update comments
* Add copyright
* Fix docs
* fix
* resolve reviewer issues
* add soap config
* fix
* fix
* resolve reviewer issues

1 parent e7b6c03 commit 44cd7a0

7 files changed: 678 additions, 1 deletion

docs/zh/api/optimizer.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -11,5 +11,6 @@
         - OptimizerList
         - RMSProp
         - SGD
+        - SOAP
       show_root_heading: true
       heading_level: 3
```

docs/zh/examples/allen_cahn.md

Lines changed: 7 additions & 0 deletions
````diff
@@ -9,7 +9,10 @@
     wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/AllenCahn/allen_cahn.mat -P ./dataset/
     # windows
     # curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/AllenCahn/allen_cahn.mat --create-dirs -o ./dataset/allen_cahn.mat
+    # Using Adam optimizer
     python allen_cahn_piratenet.py
+    # Using SOAP optimizer
+    python allen_cahn_piratenet.py TRAIN.optim=soap TRAIN.lr_scheduler.warmup_epoch=5
     ```
 
 === "模型评估命令"
@@ -19,7 +22,10 @@
     wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/AllenCahn/allen_cahn.mat -P ./dataset/
     # windows
     # curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/AllenCahn/allen_cahn.mat --create-dirs -o ./dataset/allen_cahn.mat
+    # Using Adam pretrained model
     python allen_cahn_piratenet.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/AllenCahn/allen_cahn_piratenet_pretrained.pdparams
+    # Using SOAP pretrained model
+    python allen_cahn_piratenet.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/AllenCahn/allen_cahn_piratenet_soap_pretrained.pdparams
     ```
 
 === "模型导出命令"
@@ -41,6 +47,7 @@
 | 预训练模型 | 指标 |
 |:--| :--|
 | [allen_cahn_piratenet_pretrained.pdparams](https://paddle-org.bj.bcebos.com/paddlescience/models/AllenCahn/allen_cahn_piratenet_pretrained.pdparams) | L2Rel.u: 1.2e-05 |
+| [allen_cahn_piratenet_soap_pretrained.pdparams](https://paddle-org.bj.bcebos.com/paddlescience/models/AllenCahn/allen_cahn_piratenet_soap_pretrained.pdparams) | L2Rel.u: 6.8e-6 |
 
 ## 1. 背景简介
````

examples/allen_cahn/allen_cahn_piratenet.py

Lines changed: 9 additions & 1 deletion
```diff
@@ -133,7 +133,15 @@ def gen_label_batch(input_batch):
     lr_scheduler = ppsci.optimizer.lr_scheduler.ExponentialDecay(
         **cfg.TRAIN.lr_scheduler
     )()
-    optimizer = ppsci.optimizer.Adam(lr_scheduler)(model)
+
+    if cfg.TRAIN.optim == "adam":
+        optimizer = ppsci.optimizer.Adam(lr_scheduler)(model)
+    elif cfg.TRAIN.optim == "soap":
+        optimizer = ppsci.optimizer.SOAP(lr_scheduler)(model)
+    else:
+        raise ValueError(
+            f"cfg.TRAIN.optim should be in ['adam','soap'], but got '{cfg.TRAIN.optim}'."
+        )
 
     # set validator
     tx_star = misc.cartesian_product(t_star, x_star).astype(dtype)
```

examples/allen_cahn/conf/allen_cahn_piratenet.yaml

Lines changed: 2 additions & 0 deletions
```diff
@@ -54,13 +54,15 @@ TRAIN:
   save_freq: 10
   eval_during_train: true
   eval_freq: 10
+  optim: adam
   lr_scheduler:
     epochs: ${TRAIN.epochs}
     iters_per_epoch: ${TRAIN.iters_per_epoch}
     learning_rate: 1.0e-3
     gamma: 0.9
     decay_steps: 5000
     by_epoch: false
+    warmup_epoch: 0
   batch_size: 8192
   pretrained_model_path: null
   checkpoint_path: null
```
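
Read together, the new `optim` and `warmup_epoch` keys drive the optimizer selection added in `allen_cahn_piratenet.py` above. Below is a minimal standalone sketch of how the pieces fit together without Hydra; the scheduler values are illustrative rather than the example's exact settings, and `optim_name` is a hypothetical stand-in for `cfg.TRAIN.optim`.

```python
import ppsci

# Illustrative model; the Allen-Cahn example uses a PirateNet instead.
model = ppsci.arch.MLP(("t", "x"), ("u",), 4, 256)

# Same keys as TRAIN.lr_scheduler in the YAML above (values are illustrative).
lr_scheduler = ppsci.optimizer.lr_scheduler.ExponentialDecay(
    epochs=200,
    iters_per_epoch=1000,
    learning_rate=1.0e-3,
    gamma=0.9,
    decay_steps=5000,
    by_epoch=False,
    warmup_epoch=5,  # the docs command above enables a short warmup when training with SOAP
)()

# Mirrors the new branch in allen_cahn_piratenet.py; optim_name stands in for cfg.TRAIN.optim.
optim_name = "soap"
if optim_name == "adam":
    optimizer = ppsci.optimizer.Adam(lr_scheduler)(model)
elif optim_name == "soap":
    optimizer = ppsci.optimizer.SOAP(lr_scheduler)(model)
else:
    raise ValueError(f"optim_name should be in ['adam', 'soap'], but got '{optim_name}'.")
```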

ppsci/optimizer/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -17,6 +17,7 @@
 from ppsci.optimizer import lr_scheduler
 from ppsci.optimizer.optimizer import LBFGS
 from ppsci.optimizer.optimizer import SGD
+from ppsci.optimizer.optimizer import SOAP
 from ppsci.optimizer.optimizer import Adam
 from ppsci.optimizer.optimizer import AdamW
 from ppsci.optimizer.optimizer import Momentum
@@ -32,6 +33,7 @@
     "RMSProp",
     "OptimizerList",
     "lr_scheduler",
+    "SOAP",
 ]
```

ppsci/optimizer/optimizer.py

Lines changed: 99 additions & 0 deletions
```diff
@@ -27,6 +27,7 @@
 from paddle.incubate import optimizer as incubate_optim
 from typing_extensions import Literal
 
+from ppsci.optimizer.soap import SOAP as SOAP_impl
 from ppsci.utils import logger
 from ppsci.utils import misc
 
@@ -495,6 +496,104 @@ def _apply_decay_param_fun(self, name):
         return name not in self.no_weight_decay_param_name_list
 
 
+class SOAP:
+    """
+    Improving and Stabilizing Shampoo using Adam. Implements the SOAP algorithm (https://arxiv.org/abs/2409.11321).
+
+    Args:
+        learning_rate (float, optional):
+            The learning rate to use. defaults to 0.003.
+        beta1 (float, optional):
+            Adam's betas parameter beta1. defaults to 0.95.
+        beta2 (float, optional):
+            Adam's betas parameter beta2. defaults to 0.95.
+        shampoo_beta (float, optional):
+            If >= 0, use this beta for the preconditioner (L and R in the paper, state['GG'] below) moving average instead of betas[1].
+            defaults to -1.
+        epsilon (float, optional):
+            Adam's epsilon for numerical stability. defaults to 1e-08.
+        weight_decay (float, optional): Weight decay coefficient. defaults to 0.01.
+        precondition_frequency (int, optional):
+            How often to update the preconditioner. defaults to 10.
+        max_precond_dim (int, optional):
+            Maximum dimension of the preconditioner.
+            Set to 10000 so that most common vocab sizes are excluded while layers are included. defaults to 10000.
+        merge_dims (bool, optional):
+            Whether or not to merge dimensions of the preconditioner. defaults to `False`.
+        precondition_1d (bool, optional):
+            Whether or not to precondition 1D gradients. defaults to `False`.
+        normalize_grads (bool, optional):
+            Whether or not to normalize gradients per layer.
+            Helps at large precondition_frequency (~100 in our experiments),
+            but hurts performance at small precondition_frequency (~10 in our experiments). defaults to `False`.
+        data_format (str, optional):
+            Data format of the input for convolutional layers.
+            Should be "channels_last" for NHWC and "channels_first" for NCHW. defaults to `channels_first`.
+        correct_bias (bool, optional):
+            Whether or not to use bias correction in Adam. defaults to `True`.
+
+    Examples:
+        >>> import ppsci
+        >>> model = ppsci.arch.MLP(("x",), ("u",), 5, 20)
+        >>> opt = ppsci.optimizer.SOAP(1e-3)(model)
+    """
+
+    def __init__(
+        self,
+        learning_rate: float = 3e-3,
+        beta1: float = 0.95,
+        beta2: float = 0.95,
+        shampoo_beta: float = -1,
+        epsilon: float = 1e-8,
+        weight_decay: float = 0.01,
+        precondition_frequency: int = 10,
+        max_precond_dim: int = 10000,
+        merge_dims: bool = False,  # Merge dimensions until their product is less than or equal to max_precond_dim.
+        precondition_1d: bool = False,
+        normalize_grads: bool = False,
+        data_format: str = "channels_first",
+        correct_bias: bool = True,
+    ):
+        self.learning_rate = learning_rate
+        self.beta1 = beta1
+        self.beta2 = beta2
+        self.shampoo_beta = shampoo_beta
+        self.epsilon = epsilon
+        self.weight_decay = weight_decay
+        self.precondition_frequency = precondition_frequency
+        self.max_precond_dim = max_precond_dim
+        self.merge_dims = merge_dims
+        self.precondition_1d = precondition_1d
+        self.normalize_grads = normalize_grads
+        self.data_format = data_format
+        self.correct_bias = correct_bias
+
+    def __call__(self, model_list: Union[nn.Layer, Tuple[nn.Layer, ...]]):
+        # model_list is None in static graph
+        if not isinstance(model_list, (tuple, list)):
+            model_list = (model_list,)
+        parameters = (
+            sum([m.parameters() for m in model_list], []) if model_list else None
+        )
+        opt = SOAP_impl(
+            parameters=parameters,
+            learning_rate=self.learning_rate,
+            beta1=self.beta1,
+            beta2=self.beta2,
+            shampoo_beta=self.shampoo_beta,
+            epsilon=self.epsilon,
+            weight_decay=self.weight_decay,
+            precondition_frequency=self.precondition_frequency,
+            max_precond_dim=self.max_precond_dim,
+            merge_dims=self.merge_dims,
+            precondition_1d=self.precondition_1d,
+            normalize_grads=self.normalize_grads,
+            data_format=self.data_format,
+            correct_bias=self.correct_bias,
+        )
+        return opt
+
+
 class OptimizerList:
     """OptimizerList which wrap more than one optimizer.
     NOTE: LBFGS is not supported yet.
```
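
For reference, a short usage sketch of the new `SOAP` wrapper with a couple of non-default hyperparameters. It assumes the object returned by `__call__` (the underlying `SOAP_impl`) exposes the standard Paddle optimizer interface (`step()` / `clear_grad()`); the model and data below are placeholders, not part of the commit.

```python
import paddle
import ppsci

# Small placeholder model (same constructor as the docstring example above).
model = ppsci.arch.MLP(("x",), ("u",), 5, 20)

# Construct SOAP with a few non-default hyperparameters, then bind it to the model.
optimizer = ppsci.optimizer.SOAP(
    learning_rate=3e-3,
    weight_decay=0.01,
    precondition_frequency=10,  # how often the preconditioner is refreshed
)(model)

# One illustrative optimization step on random data, assuming the returned
# object follows the standard Paddle optimizer interface.
x = paddle.rand([128, 1])
out = model({"x": x})
loss = (out["u"] ** 2).mean()
loss.backward()
optimizer.step()
optimizer.clear_grad()
```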
