|
27 | 27 | from paddle.incubate import optimizer as incubate_optim
|
28 | 28 | from typing_extensions import Literal
|
29 | 29 |
|
| 30 | +from ppsci.optimizer.soap import SOAP as SOAP_impl |
30 | 31 | from ppsci.utils import logger
|
31 | 32 | from ppsci.utils import misc
|
32 | 33 |
|
@@ -495,6 +496,104 @@ def _apply_decay_param_fun(self, name):
|
495 | 496 | return name not in self.no_weight_decay_param_name_list
|
496 | 497 |
|
497 | 498 |
|
| 499 | +class SOAP: |
| 500 | + """ |
| 501 | + Improving and Stabilizing Shampoo using Adam. Implements the SOAP algorithm (https://arxiv.org/abs/2409.11321). |
| 502 | +
|
| 503 | + Args: |
| 504 | + learning_rate (float, optional): |
| 505 | + The learning rate to use. Defaults to 0.003. |
| 506 | + beta1 (float, optional): |
| 507 | + Adam's beta1 coefficient. Defaults to 0.95. |
| 508 | + beta2 (float, optional): |
| 509 | + Adam's beta2 coefficient. Defaults to 0.95. |
| 510 | + shampoo_beta (float, optional): |
| 511 | + If >= 0, use this beta for the moving average of the preconditioner (L and R in the paper, state['GG'] in the implementation) instead of beta2. |
| 512 | + Defaults to -1. |
| 513 | + epsilon (float, optional): |
| 514 | + Adam's epsilon for numerical stability. Defaults to 1e-08. |
| 515 | + weight_decay (float, optional): Weight decay coefficient. Defaults to 0.01. |
| 516 | + precondition_frequency (int, optional): |
| 517 | + How often to update the preconditioner. Defaults to 10. |
| 518 | + max_precond_dim (int, optional): |
| 519 | + Maximum dimension of the preconditioner. |
| 520 | + The default of 10000 excludes most common vocabulary sizes while still covering typical layers. Defaults to 10000. |
| 521 | + merge_dims (bool, optional): |
| 522 | + Whether or not to merge dimensions of the preconditioner. Defaults to `False`. |
| 523 | + precondition_1d (bool, optional): |
| 524 | + Whether or not to precondition 1D gradients. Defaults to `False`. |
| 525 | + normalize_grads (bool, optional): |
| 526 | + Whether or not to normalize gradients per layer. |
| 527 | + Helps at large precondition_frequency (~100 in the paper's experiments), |
| 528 | + but hurts performance at small precondition_frequency (~10). Defaults to `False`. |
| 529 | + data_format (str, optional): |
| 530 | + Data format of the input for convolutional layers. |
| 531 | + Should be "channels_last" for NHWC inputs and "channels_first" for NCHW. Defaults to `channels_first`. |
| 532 | + correct_bias (bool, optional): |
| 533 | + Whether or not to use bias correction in Adam. Defaults to `True`. |
| 534 | +
|
| 535 | + Examples: |
| 536 | + >>> import ppsci |
| 537 | + >>> model = ppsci.arch.MLP(("x",), ("u",), 5, 20) |
| 538 | + >>> opt = ppsci.optimizer.SOAP(1e-3)(model) |
| 539 | + """ |
| 540 | + |
| 541 | + def __init__( |
| 542 | + self, |
| 543 | + learning_rate: float = 3e-3, |
| 544 | + beta1: float = 0.95, |
| 545 | + beta2: float = 0.95, |
| 546 | + shampoo_beta: float = -1, |
| 547 | + epsilon: float = 1e-8, |
| 548 | + weight_decay: float = 0.01, |
| 549 | + precondition_frequency: int = 10, |
| 550 | + max_precond_dim: int = 10000, # Maximum dimension of the preconditioner. |
| 551 | + merge_dims: bool = False, # Merge dimensions until the product of the merged dimensions is at most max_precond_dim. |
| 552 | + precondition_1d: bool = False, |
| 553 | + normalize_grads: bool = False, |
| 554 | + data_format: str = "channels_first", |
| 555 | + correct_bias: bool = True, |
| 556 | + ): |
| 557 | + self.learning_rate = learning_rate |
| 558 | + self.beta1 = beta1 |
| 559 | + self.beta2 = beta2 |
| 560 | + self.shampoo_beta = shampoo_beta |
| 561 | + self.epsilon = epsilon |
| 562 | + self.weight_decay = weight_decay |
| 563 | + self.precondition_frequency = precondition_frequency |
| 564 | + self.max_precond_dim = max_precond_dim |
| 565 | + self.merge_dims = merge_dims |
| 566 | + self.precondition_1d = precondition_1d |
| 567 | + self.normalize_grads = normalize_grads |
| 568 | + self.data_format = data_format |
| 569 | + self.correct_bias = correct_bias |
| 570 | + |
| 571 | + def __call__(self, model_list: Union[nn.Layer, Tuple[nn.Layer, ...]]): |
| 572 | + # model_list is None in static graph |
| 573 | + if model_list is not None and not isinstance(model_list, (tuple, list)): |
| 574 | + model_list = (model_list,) |
| 575 | + parameters = ( |
| 576 | + sum([m.parameters() for m in model_list], []) if model_list else None |
| 577 | + ) |
| 578 | + opt = SOAP_impl( |
| 579 | + parameters=parameters, |
| 580 | + learning_rate=self.learning_rate, |
| 581 | + beta1=self.beta1, |
| 582 | + beta2=self.beta2, |
| 583 | + shampoo_beta=self.shampoo_beta, |
| 584 | + epsilon=self.epsilon, |
| 585 | + weight_decay=self.weight_decay, |
| 586 | + precondition_frequency=self.precondition_frequency, |
| 587 | + max_precond_dim=self.max_precond_dim, |
| 588 | + merge_dims=self.merge_dims, |
| 589 | + precondition_1d=self.precondition_1d, |
| 590 | + normalize_grads=self.normalize_grads, |
| 591 | + data_format=self.data_format, |
| 592 | + correct_bias=self.correct_bias, |
| 593 | + ) |
| 594 | + return opt |
| 595 | + |
| 596 | + |
498 | 597 | class OptimizerList:
|
499 | 598 | """OptimizerList which wrap more than one optimizer.
|
500 | 599 | NOTE: LBFGS is not supported yet.
|
|
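For context, a minimal sketch of how the new wrapper could be driven in a plain dynamic-graph loop. It assumes the object returned by `__call__` follows Paddle's standard `step()`/`clear_grad()` optimizer interface; the random input and the mean-square loss are placeholders, not a real PDE objective:

```python
import paddle

import ppsci

# Build a small MLP and create the SOAP optimizer from it, mirroring the docstring example.
model = ppsci.arch.MLP(("x",), ("u",), 5, 20)
optimizer = ppsci.optimizer.SOAP(learning_rate=1e-3)(model)

# One illustrative training step on random data (placeholder loss).
x = paddle.rand([16, 1])
out = model({"x": x})
loss = paddle.mean(out["u"] ** 2)
loss.backward()
optimizer.step()  # assumed standard Paddle optimizer interface
optimizer.clear_grad()
```

In a full PaddleScience workflow the constructed optimizer would typically be handed to `ppsci.solver.Solver` rather than stepped manually as above.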