Skip to content

Commit c8fdf41

Browse files
committed
update: reset
1 parent 7b59b46 commit c8fdf41

File tree

1 file changed

+2
-1
lines changed
  • pytorch_optimizer/optimizer

1 file changed

+2
-1
lines changed

pytorch_optimizer/optimizer/sm3.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ def reset(self):
5050
for p in group['params']:
5151
state = self.state[p]
5252

53-
state['momentum_buffer'] = 0.0
53+
state['step'] = 0
54+
state['momentum_buffer'] = torch.zeros_like(p)
5455

5556
@staticmethod
5657
def max_reduce_except_dim(x: torch.Tensor, dim: int) -> torch.Tensor:

0 commit comments

Comments
 (0)