Skip to content

Commit 96d6ac6

Browse files
committed
fix: typo
1 parent e0acc05 commit 96d6ac6

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

pytorch_optimizer/optimizer/adalite.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def reset(self):
7171
state['m_avg'] = torch.zeros_like(p)
7272
state['v_avg'] = torch.zeros_like(p)
7373
else:
74-
state['v_avg_0'] = torch.zeros_like(p.shape(dim=1))
74+
state['v_avg_0'] = torch.zeros_like(p.mean(dim=1))
7575
state['v_avg_1'] = torch.zeros_like(p.mean(dim=0))
7676

7777
state['m_avg_c'] = torch.zeros_like(p.mean(dim=1)[:, None])
@@ -108,7 +108,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
108108
state['m_avg'] = torch.zeros_like(p)
109109
state['v_avg'] = torch.zeros_like(p)
110110
else:
111-
state['v_avg_0'] = torch.zeros_like(p.shape(dim=1))
111+
state['v_avg_0'] = torch.zeros_like(p.mean(dim=1))
112112
state['v_avg_1'] = torch.zeros_like(p.mean(dim=0))
113113

114114
state['m_avg_c'] = torch.zeros_like(p.mean(dim=1)[:, None])

0 commit comments

Comments
 (0)