Skip to content

Commit dea159a

Browse files
committed
update: ndim to dim()
1 parent 27db169 commit dea159a

File tree

1 file changed

+2
-8
lines changed

1 file changed

+2
-8
lines changed

pytorch_optimizer/optimizer/spam.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def update_masks(self) -> None:
155155
for group in self.param_groups:
156156
for p in group['params']:
157157
state = self.state[p]
158-
if p.ndim == 2 and 'mask' in state:
158+
if p.dim() == 2 and 'mask' in state:
159159
state['mask'] = self.update_mask_random(p, state['mask'])
160160
p.mask = state['mask']
161161

@@ -177,13 +177,7 @@ def __str__(self) -> str:
177177

178178
@torch.no_grad()
179179
def reset(self):
180-
for group in self.param_groups:
181-
group['step'] = 0
182-
for p in group['params']:
183-
state = self.state[p]
184-
185-
state['exp_avg'] = torch.zeros_like(p)
186-
state['exp_avg_sq'] = torch.zeros_like(p)
180+
pass
187181

188182
@torch.no_grad()
189183
def step(self, closure: CLOSURE = None) -> LOSS:

0 commit comments

Comments
 (0)