Commit be0351d

Merge pull request #132 from kozistr/fix/device
[Fix] variables are not located on the same device with the gradients
2 parents: 19dcf2b + 3172b69
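The failure mode behind this fix: torch.zeros defaults to CPU allocation, so state tensors created with only dtype=grad.dtype land on a different device than CUDA gradients, and the first in-place update raises a device-mismatch RuntimeError. A minimal standalone sketch (illustrative only, not code from this repo):

import torch

# Gradients live wherever the parameters live; torch.zeros defaults to CPU.
grad = torch.randn(4, 3, device='cuda' if torch.cuda.is_available() else 'cpu')

# Before the fix: dtype matches, device may not.
row = torch.zeros(grad.shape[:-1], dtype=grad.dtype)

try:
    row.add_(grad.mean(dim=-1))  # fails on GPU: tensors on different devices
except RuntimeError as err:
    print(err)

# After the fix: allocate directly on the gradient's device.
row = torch.zeros(grad.shape[:-1], dtype=grad.dtype, device=grad.device)
row.add_(grad.mean(dim=-1))  # fine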

3 files changed: +15 / -13 lines

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "pytorch_optimizer"
-version = "2.6.0"
+version = "2.6.1"
 description = "optimizer & lr scheduler implementations in PyTorch with clean-code, strict types. Also, including useful optimization ideas."
 license = "Apache-2.0"
 authors = ["kozistr <[email protected]>"]

pytorch_optimizer/optimizer/adafactor.py

Lines changed: 9 additions & 9 deletions
@@ -81,8 +81,8 @@ def reset(self):
             state['exp_avg'] = torch.zeros_like(p)

         if factored:
-            state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1], dtype=grad.dtype)
-            state['exp_avg_sq_col'] = torch.zeros(grad_shape[:-2] + grad_shape[-1:], dtype=grad.dtype)
+            state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1]).to(grad)
+            state['exp_avg_sq_col'] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
         else:
             state['exp_avg_sq'] = torch.zeros_like(grad)

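For reset(), the fix takes the tensor-overload route: Tensor.to(other) adopts both the dtype and the device of other in a single call. A small illustration (standalone, hypothetical shapes):

import torch

# Stand-in gradient; half precision to show the dtype is carried over too.
grad = torch.randn(8, 16).to(torch.float16)

row = torch.zeros(grad.shape[:-1]).to(grad)

assert row.dtype == grad.dtype    # torch.float16
assert row.device == grad.device  # same device as the gradient

Note the trade-off: .to(grad) first allocates on the default device and then converts, while the explicit dtype=/device= form used in step() below allocates in place; for these small row/column accumulators the difference is negligible.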
@@ -114,8 +114,8 @@ def approximate_sq_grad(
         exp_avg_sq_col: torch.Tensor,
         output: torch.Tensor,
     ):
-        r"""Get approximate squared gradient."""
-        r_factor: torch.Tensor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1)).rsqrt_().unsqueeze(-1)
+        r"""Get approximation of EMA of squared gradient."""
+        r_factor: torch.Tensor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
         c_factor: torch.Tensor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
         torch.mul(r_factor, c_factor, out=output)

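The keepdim=True change is a separate correctness fix: without it, the row mean drops its reduced axis, which only broadcasts correctly for 2D (unbatched) factored parameters. A shape-only sketch with hypothetical sizes:

import torch

exp_avg_sq_row = torch.rand(2, 4)  # e.g. (batch, rows) for a 3D parameter

# Without keepdim the reduced axis disappears: (2,) cannot broadcast against (2, 4).
print(exp_avg_sq_row.mean(dim=-1).shape)  # torch.Size([2])

# With keepdim the mean stays (2, 1) and divides each row by its own mean.
r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
print(r_factor.shape)  # torch.Size([2, 4, 1])

For 1D rows the two variants coincide (the mean is a scalar), which is why the bug only surfaces on parameters with three or more dimensions.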
@@ -145,8 +145,10 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                     state['exp_avg'] = torch.zeros_like(p)

                 if factored:
-                    state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1], dtype=grad.dtype)
-                    state['exp_avg_sq_col'] = torch.zeros(grad_shape[:-2] + grad_shape[-1:], dtype=grad.dtype)
+                    state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1], dtype=grad.dtype, device=grad.device)
+                    state['exp_avg_sq_col'] = torch.zeros(
+                        grad_shape[:-2] + grad_shape[-1:], dtype=grad.dtype, device=grad.device
+                    )
                 else:
                     state['exp_avg_sq'] = torch.zeros_like(grad)

@@ -166,11 +168,9 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                     exp_avg_sq_row.mul_(beta2_t).add_(update.mean(dim=-1), alpha=1.0 - beta2_t)
                     exp_avg_sq_col.mul_(beta2_t).add_(update.mean(dim=-2), alpha=1.0 - beta2_t)

-                    # Approximation of exponential moving average of square of gradient
-                    self.approximate_sq_grad(exp_avg_sq_row, exp_avg_sq_col, update)
+                    self.approximate_sq_grad(exp_avg_sq_row, exp_avg_sq_col, output=update)
                 else:
                     exp_avg_sq = state['exp_avg_sq']
-
                     exp_avg_sq.mul_(beta2_t).add_(update, alpha=1.0 - beta2_t)
                     torch.rsqrt(exp_avg_sq, out=update)

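A quick end-to-end check of the device fix, assuming the library exports the optimizer as AdaFactor (name inferred from the file path above, not confirmed by this diff):

import torch
from pytorch_optimizer import AdaFactor  # exported name assumed

device = 'cuda' if torch.cuda.is_available() else 'cpu'
param = torch.nn.Parameter(torch.randn(16, 32, device=device))  # 2D -> factored path
optimizer = AdaFactor([param], lr=1e-3)

param.grad = torch.randn_like(param)  # gradient lives on `device`
optimizer.step()

# With the fix, the factored second-moment states follow the gradient's device.
state = optimizer.state[param]
assert state['exp_avg_sq_row'].device == param.grad.device
assert state['exp_avg_sq_col'].device == param.grad.device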
pytorch_optimizer/optimizer/sm3.py

Lines changed: 5 additions & 3 deletions
@@ -98,12 +98,14 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                     state['momentum_buffer'] = torch.zeros_like(p)

                 if grad.is_sparse:
-                    state['accumulator_0'] = torch.zeros(shape[0])
+                    state['accumulator_0'] = torch.zeros(shape[0], dtype=grad.dtype, device=grad.device)
                 elif rank == 0:
-                    state['accumulator_0'] = torch.zeros(shape)
+                    state['accumulator_0'] = torch.zeros_like(p)
                 else:
                     for i in range(rank):
-                        state[f'accumulator_{i}'] = torch.zeros([1] * i + [shape[i]] + [1] * (rank - 1 - i))
+                        state[f'accumulator_{i}'] = torch.zeros(
+                            [1] * i + [shape[i]] + [1] * (rank - 1 - i), dtype=grad.dtype, device=grad.device
+                        )

                 state['step'] += 1

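For context, SM3 keeps one broadcastable accumulator per tensor dimension: the i-th has size shape[i] along axis i and 1 everywhere else, so each of these small tensors must also be pinned to the gradient's device. A shape sketch with a hypothetical rank-3 gradient:

import torch

grad = torch.randn(4, 5, 6)  # hypothetical rank-3 gradient
rank, shape = grad.dim(), grad.shape

accumulators = [
    torch.zeros([1] * i + [shape[i]] + [1] * (rank - 1 - i),
                dtype=grad.dtype, device=grad.device)
    for i in range(rank)
]
print([tuple(a.shape) for a in accumulators])
# [(4, 1, 1), (1, 5, 1), (1, 1, 6)] -- each broadcasts against the full gradient

# Their elementwise minimum recovers the gradient's full shape via broadcasting.
acc = torch.minimum(torch.minimum(accumulators[0], accumulators[1]), accumulators[2])
print(acc.shape)  # torch.Size([4, 5, 6])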