Skip to content

Commit 3172b69

Browse files
committed
update: initialization
1 parent 7a8cf29 commit 3172b69

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

pytorch_optimizer/optimizer/adafactor.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,8 +145,10 @@ def step(self, closure: CLOSURE = None) -> LOSS:
145145
state['exp_avg'] = torch.zeros_like(p)
146146

147147
if factored:
148-
state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1]).to(grad)
149-
state['exp_avg_sq_col'] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
148+
state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1], dtype=grad.dtype, device=grad.device)
149+
state['exp_avg_sq_col'] = torch.zeros(
150+
grad_shape[:-2] + grad_shape[-1:], dtype=grad.dtype, device=grad.device
151+
)
150152
else:
151153
state['exp_avg_sq'] = torch.zeros_like(grad)
152154

pytorch_optimizer/optimizer/sm3.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,13 +98,13 @@ def step(self, closure: CLOSURE = None) -> LOSS:
9898
state['momentum_buffer'] = torch.zeros_like(p)
9999

100100
if grad.is_sparse:
101-
state['accumulator_0'] = torch.zeros(shape[0], device=grad.device)
101+
state['accumulator_0'] = torch.zeros(shape[0], dtype=grad.dtype, device=grad.device)
102102
elif rank == 0:
103103
state['accumulator_0'] = torch.zeros_like(p)
104104
else:
105105
for i in range(rank):
106106
state[f'accumulator_{i}'] = torch.zeros(
107-
[1] * i + [shape[i]] + [1] * (rank - 1 - i), device=grad.device
107+
[1] * i + [shape[i]] + [1] * (rank - 1 - i), dtype=grad.dtype, device=grad.device
108108
)
109109

110110
state['step'] += 1

0 commit comments

Comments
 (0)