Skip to content

Commit 169ae96

Browse files
committed
fix: accumulator is located on the CPU, not grad.device
1 parent ef9f0fe commit 169ae96

File tree

1 file changed

+5
-3
lines changed
  • pytorch_optimizer/optimizer

1 file changed

+5
-3
lines changed

pytorch_optimizer/optimizer/sm3.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,12 +98,14 @@ def step(self, closure: CLOSURE = None) -> LOSS:
9898
state['momentum_buffer'] = torch.zeros_like(p)
9999

100100
if grad.is_sparse:
101-
state['accumulator_0'] = torch.zeros(shape[0])
101+
state['accumulator_0'] = torch.zeros(shape[0], device=grad.device)
102102
elif rank == 0:
103-
state['accumulator_0'] = torch.zeros(shape)
103+
state['accumulator_0'] = torch.zeros_like(p)
104104
else:
105105
for i in range(rank):
106-
state[f'accumulator_{i}'] = torch.zeros([1] * i + [shape[i]] + [1] * (rank - 1 - i))
106+
state[f'accumulator_{i}'] = torch.zeros(
107+
[1] * i + [shape[i]] + [1] * (rank - 1 - i), device=grad.device
108+
)
107109

108110
state['step'] += 1
109111

0 commit comments

Comments
 (0)