We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent ed61369 commit ba341dbCopy full SHA for ba341db
pytorch_optimizer/optimizer/agc.py
@@ -14,11 +14,9 @@ def agc(
14
:param agc_clip_val: float. norm clip.
15
:param eps: float. simple stop from div by zero and no relation to standard optimizer eps.
16
"""
17
- p_norm = unit_norm(p).clamp_min_(agc_eps)
+ max_norm = unit_norm(p).clamp_min_(agc_eps).mul_(agc_clip_val)
18
g_norm = unit_norm(grad).clamp_min_(eps)
19
20
- max_norm = p_norm * agc_clip_val
21
-
22
clipped_grad = grad * (max_norm / g_norm)
23
24
return torch.where(g_norm > max_norm, clipped_grad, grad)
0 commit comments