Skip to content

Commit ef9f0fe

Browse files
committed
fix: allocate exp_avg_sq_row/col on grad.device (they were being created on the CPU)
1 parent 19dcf2b commit ef9f0fe

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

pytorch_optimizer/optimizer/adafactor.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,10 @@ def reset(self):
81 81
state['exp_avg'] = torch.zeros_like(p)
82 82

83 83
if factored:
84-
state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1], dtype=grad.dtype)
85-
state['exp_avg_sq_col'] = torch.zeros(grad_shape[:-2] + grad_shape[-1:], dtype=grad.dtype)
84+
state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1], dtype=grad.dtype, device=grad.device)
85+
state['exp_avg_sq_col'] = torch.zeros(
86+
grad_shape[:-2] + grad_shape[-1:], dtype=grad.dtype, device=grad.device
87+
)
86 88
else:
87 89
state['exp_avg_sq'] = torch.zeros_like(grad)
88 90

@@ -145,8 +147,10 @@ def step(self, closure: CLOSURE = None) -> LOSS:
145 147
state['exp_avg'] = torch.zeros_like(p)
146 148

147 149
if factored:
148-
state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1], dtype=grad.dtype)
149-
state['exp_avg_sq_col'] = torch.zeros(grad_shape[:-2] + grad_shape[-1:], dtype=grad.dtype)
150+
state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1], dtype=grad.dtype, device=grad.device)
151+
state['exp_avg_sq_col'] = torch.zeros(
152+
grad_shape[:-2] + grad_shape[-1:], dtype=grad.dtype, device=grad.device
153+
)
150 154
else:
151 155
state['exp_avg_sq'] = torch.zeros_like(grad)
152 156

@@ -170,7 +174,6 @@ def step(self, closure: CLOSURE = None) -> LOSS:
170 174
self.approximate_sq_grad(exp_avg_sq_row, exp_avg_sq_col, update)
171 175
else:
172 176
exp_avg_sq = state['exp_avg_sq']
173-
174 177
exp_avg_sq.mul_(beta2_t).add_(update, alpha=1.0 - beta2_t)
175 178
torch.rsqrt(exp_avg_sq, out=update)
176 179

0 commit comments

Comments
 (0)