Commit 03e78cd

feature: support Gradient Centralization
1 parent 806f1f9 commit 03e78cd

File tree: 1 file changed, +7 -0 lines


pytorch_optimizer/adamp.py

Lines changed: 7 additions & 0 deletions
@@ -5,6 +5,7 @@
 import torch.nn.functional as F
 from torch.optim.optimizer import Optimizer
 
+from pytorch_optimizer.gc import centralize_gradient
 from pytorch_optimizer.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
 
 
@@ -32,6 +33,7 @@ def __init__(
         weight_decay: float = 0.0,
         delta: float = 0.1,
         wd_ratio: float = 0.1,
+        use_gc: bool = False,
         nesterov: bool = False,
         eps: float = 1e-8,
     ):
@@ -43,13 +45,15 @@ def __init__(
         :param delta: float. threshold that determines whether a set of parameters is scale invariant or not
         :param wd_ratio: float. relative weight decay applied on scale-invariant parameters compared to that applied
             on scale-variant parameters
+        :param use_gc: bool. use gradient centralization
         :param nesterov: bool. enables Nesterov momentum
         :param eps: float. term added to the denominator to improve numerical stability
         """
         self.lr = lr
         self.betas = betas
         self.weight_decay = weight_decay
         self.wd_ratio = wd_ratio
+        self.use_gc = use_gc
         self.eps = eps
 
         self.check_valid_parameters()
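
With the new flag, Gradient Centralization becomes opt-in through the constructor. A hypothetical usage sketch follows; the `AdamP` class name and import path are assumed from the file name, and the model and training loop are placeholders:

import torch

from pytorch_optimizer.adamp import AdamP  # class name assumed from pytorch_optimizer/adamp.py

model = torch.nn.Linear(10, 2)  # placeholder model
optimizer = AdamP(model.parameters(), lr=1e-3, use_gc=True)  # enable gradient centralization

for _ in range(3):
    optimizer.zero_grad()
    loss = model(torch.randn(4, 10)).sum()  # dummy forward pass and loss
    loss.backward()
    optimizer.step()  # gradients are centralized before the Adam-style update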
@@ -146,6 +150,9 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
                 grad = p.grad.data
 
+                if self.use_gc:
+                    grad = centralize_gradient(grad, gc_conv_only=False)
+
                 exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

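The `centralize_gradient` helper imported above lives elsewhere in the repository (pytorch_optimizer/gc.py) and is not part of this diff. For reference, a minimal sketch of the standard Gradient Centralization operation (Yong et al., 2020) could look like the following; the exact behaviour of the `gc_conv_only` flag is an assumption (restrict to 4-D convolutional weights when True, also centralize 2-D weights when False):

import torch


def centralize_gradient(grad: torch.Tensor, gc_conv_only: bool = False) -> torch.Tensor:
    """Subtract the per-output-channel mean of the gradient (a sketch, not the repository's implementation)."""
    size: int = grad.dim()

    # Centralize only multi-dimensional gradients; 1-D tensors (biases, norms) are left untouched.
    if (gc_conv_only and size > 3) or (not gc_conv_only and size > 1):
        grad.add_(-grad.mean(dim=tuple(range(1, size)), keepdim=True))
    return grad

The call added in step() passes gc_conv_only=False, so under this sketch both convolutional and fully connected weight gradients would be centralized.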
0 commit comments
