Skip to content

Commit ba341db

Browse files
committed
update: agc
1 parent ed61369 commit ba341db

File tree

1 file changed

+1
-3
lines changed
  • pytorch_optimizer/optimizer

1 file changed

+1
-3
lines changed

pytorch_optimizer/optimizer/agc.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,9 @@ def agc(
1414
:param agc_clip_val: float. norm clip.
1515
:param eps: float. simple stop from div by zero and no relation to standard optimizer eps.
1616
"""
17-
p_norm = unit_norm(p).clamp_min_(agc_eps)
17+
max_norm = unit_norm(p).clamp_min_(agc_eps).mul_(agc_clip_val)
1818
g_norm = unit_norm(grad).clamp_min_(eps)
1919

20-
max_norm = p_norm * agc_clip_val
21-
2220
clipped_grad = grad * (max_norm / g_norm)
2321

2422
return torch.where(g_norm > max_norm, clipped_grad, grad)

0 commit comments

Comments
 (0)