 import torch.nn.functional as F
 from torch.optim.optimizer import Optimizer
 
+from pytorch_optimizer.gc import centralize_gradient
 from pytorch_optimizer.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
 
 
@@ -32,6 +33,7 @@ def __init__(
         weight_decay: float = 0.0,
         delta: float = 0.1,
         wd_ratio: float = 0.1,
+        use_gc: bool = False,
         nesterov: bool = False,
         eps: float = 1e-8,
     ):
@@ -43,13 +45,15 @@ def __init__(
         :param delta: float. threshold that determines whether a set of parameters is scale invariant or not
         :param wd_ratio: float. relative weight decay applied on scale-invariant parameters compared to that applied
             on scale-variant parameters
+        :param use_gc: bool. use gradient centralization
         :param nesterov: bool. enables Nesterov momentum
         :param eps: float. term added to the denominator to improve numerical stability
         """
         self.lr = lr
         self.betas = betas
         self.weight_decay = weight_decay
         self.wd_ratio = wd_ratio
+        self.use_gc = use_gc
         self.eps = eps
 
         self.check_valid_parameters()
@@ -146,6 +150,9 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
                 grad = p.grad.data
 
+                if self.use_gc:
+                    grad = centralize_gradient(grad, gc_conv_only=False)
+
                 exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
 
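For context, the imported `centralize_gradient` helper applies Gradient Centralization (Yong et al., 2020): the gradient of each multi-dimensional weight tensor is re-centered by subtracting its mean taken over all dimensions except the first (output) dimension. Below is a minimal sketch consistent with the call site above, where `gc_conv_only=False` means fully connected weights are centralized as well as convolutional ones; the exact body in `pytorch_optimizer.gc` may differ.

```python
import torch


def centralize_gradient(grad: torch.Tensor, gc_conv_only: bool = False) -> torch.Tensor:
    """Sketch of Gradient Centralization: subtract the per-output-channel mean
    of the gradient, computed over all remaining dimensions.

    gc_conv_only=True restricts centralization to conv-style weights (dim > 3);
    gc_conv_only=False also centralizes fully connected weights (dim > 1).
    """
    size: int = grad.dim()
    if (gc_conv_only and size > 3) or (not gc_conv_only and size > 1):
        grad = grad - grad.mean(dim=tuple(range(1, size)), keepdim=True)
    return grad
```

With `use_gc=True` passed to the constructor, the gradient is centralized immediately after `grad = p.grad.data`, i.e. before the exponential moving averages `exp_avg` and `exp_avg_sq` are updated.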