Skip to content

Commit ba85681

Browse files
committed
refactor: nan_to_num()
1 parent 6bd8afc commit ba85681

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

pytorch_optimizer/optimizer/nero.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
9292
bias_correction: float = 1.0 - self.beta ** state['step']
9393

9494
grad_normed = grad / ((exp_avg_sq / bias_correction).sqrt() + self.eps)
95-
grad_normed[torch.isnan(grad_normed)] = 0.0
95+
torch.nan_to_num(grad_normed, nan=0.0, out=grad_normed)
9696

9797
p.sub_(group['lr'] * state['scale'] * grad_normed)
9898

0 commit comments

Comments
 (0)