Skip to content

Commit 806f1f9

Browse files
committed
refactor: GC
1 parent 538963f commit 806f1f9

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

pytorch_optimizer/ranger.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import torch
55
from torch.optim.optimizer import Optimizer
66

7+
from pytorch_optimizer.gc import centralize_gradient
78
from pytorch_optimizer.types import BETAS, BUFFER, CLOSURE, DEFAULTS, LOSS, PARAMETERS
89

910

@@ -117,8 +118,8 @@ def step(self, _: CLOSURE = None) -> LOSS:
117118
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
118119
beta1, beta2 = group['betas']
119120

120-
if grad.dim() > self.gc_gradient_threshold:
121-
grad.add_(-grad.mean(dim=tuple(range(1, grad.dim())), keepdim=True))
121+
if self.use_gc and grad.dim() > self.gc_gradient_threshold:
122+
grad = centralize_gradient(grad, gc_conv_only=False)
122123

123124
state['step'] += 1
124125

0 commit comments

Comments
 (0)