We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent ef9f0fe commit 169ae96Copy full SHA for 169ae96
pytorch_optimizer/optimizer/sm3.py
@@ -98,12 +98,14 @@ def step(self, closure: CLOSURE = None) -> LOSS:
98
state['momentum_buffer'] = torch.zeros_like(p)
99
100
if grad.is_sparse:
101
- state['accumulator_0'] = torch.zeros(shape[0])
+ state['accumulator_0'] = torch.zeros(shape[0], device=grad.device)
102
elif rank == 0:
103
- state['accumulator_0'] = torch.zeros(shape)
+ state['accumulator_0'] = torch.zeros_like(p)
104
else:
105
for i in range(rank):
106
- state[f'accumulator_{i}'] = torch.zeros([1] * i + [shape[i]] + [1] * (rank - 1 - i))
+ state[f'accumulator_{i}'] = torch.zeros(
107
+ [1] * i + [shape[i]] + [1] * (rank - 1 - i), device=grad.device
108
+ )
109
110
state['step'] += 1
111
0 commit comments