Skip to content

Commit b0124e0

Browse files
authored
Fix SRMM to allow operation beyond memory_length
1 parent da65344 commit b0124e0

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

pytorch_optimizer/optimizer/srmm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
5353
group['step'] = 1
5454

5555
w_t: float = (
56-
(group['step'] + 1) % (group['memory_length'] if group['memory_length'] is not None else 1)
56+
(group['step'] % (group['memory_length'] if group['memory_length'] is not None else 1)) + 1
5757
) ** -group['beta']
5858

5959
for p in group['params']:

0 commit comments

Comments
 (0)