
Commit 7a8cf29

fix: approximate_sq_grad
1 parent fb45a20 commit 7a8cf29


pytorch_optimizer/optimizer/adafactor.py

Lines changed: 7 additions & 12 deletions
@@ -81,10 +81,8 @@ def reset(self):
                     state['exp_avg'] = torch.zeros_like(p)

                 if factored:
-                    state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1], dtype=grad.dtype, device=grad.device)
-                    state['exp_avg_sq_col'] = torch.zeros(
-                        grad_shape[:-2] + grad_shape[-1:], dtype=grad.dtype, device=grad.device
-                    )
+                    state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1]).to(grad)
+                    state['exp_avg_sq_col'] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
                 else:
                     state['exp_avg_sq'] = torch.zeros_like(grad)

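Note on the initialization change above: `.to(grad)` copies both the dtype and the device of the gradient, so it matches the removed `dtype=grad.dtype, device=grad.device` keywords. A minimal sketch (illustrative shapes, not code from the repository) of the resulting factored state for a 2-D weight gradient:

import torch

# Illustrative 2-D gradient; the (128, 64) shape and float64 dtype are assumptions.
grad = torch.randn(128, 64, dtype=torch.float64)
grad_shape = grad.shape

# Row statistics drop the last dim; column statistics drop the second-to-last dim.
exp_avg_sq_row = torch.zeros(grad_shape[:-1]).to(grad)                    # shape (128,)
exp_avg_sq_col = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)  # shape (64,)

# .to(grad) propagates both dtype and device from the gradient.
assert exp_avg_sq_row.dtype == grad.dtype and exp_avg_sq_row.device == grad.device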

@@ -116,8 +114,8 @@ def approximate_sq_grad(
         exp_avg_sq_col: torch.Tensor,
         output: torch.Tensor,
     ):
-        r"""Get approximate squared gradient."""
-        r_factor: torch.Tensor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1)).rsqrt_().unsqueeze(-1)
+        r"""Get approximation of EMA of squared gradient."""
+        r_factor: torch.Tensor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
         c_factor: torch.Tensor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
         torch.mul(r_factor, c_factor, out=output)

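The substantive fix in this hunk is `keepdim=True`: `exp_avg_sq_row` has shape `grad_shape[:-1]`, so for gradients with more than two dimensions, `mean(dim=-1)` without `keepdim` drops an axis and the following division either broadcasts against the wrong dimension or raises a size mismatch. A small standalone sketch (a hypothetical helper, not the library's API) showing the fixed factorization on a 3-D tensor:

import torch

def approx_inv_sqrt(exp_avg_sq_row: torch.Tensor, exp_avg_sq_col: torch.Tensor) -> torch.Tensor:
    # keepdim=True keeps the reduced axis so each row statistic is normalized
    # by the mean taken over its own last dimension.
    r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt().unsqueeze(-1)
    c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
    return r_factor * c_factor  # rank-1 approximation of 1 / sqrt(V)

# 3-D "squared gradient" of shape (4, 5, 6); row stats have shape (4, 5) and
# column stats shape (4, 6). Without keepdim=True, the (4,)-shaped mean would
# fail to broadcast against the (4, 5)-shaped row statistics.
v = torch.rand(4, 5, 6) + 1e-3
row, col = v.mean(dim=-1), v.mean(dim=-2)
approx = approx_inv_sqrt(row, col)
print(approx.shape)  # torch.Size([4, 5, 6])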

@@ -147,10 +145,8 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                     state['exp_avg'] = torch.zeros_like(p)

                 if factored:
-                    state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1], dtype=grad.dtype, device=grad.device)
-                    state['exp_avg_sq_col'] = torch.zeros(
-                        grad_shape[:-2] + grad_shape[-1:], dtype=grad.dtype, device=grad.device
-                    )
+                    state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1]).to(grad)
+                    state['exp_avg_sq_col'] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
                 else:
                     state['exp_avg_sq'] = torch.zeros_like(grad)


@@ -170,8 +166,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                     exp_avg_sq_row.mul_(beta2_t).add_(update.mean(dim=-1), alpha=1.0 - beta2_t)
                     exp_avg_sq_col.mul_(beta2_t).add_(update.mean(dim=-2), alpha=1.0 - beta2_t)

-                    # Approximation of exponential moving average of square of gradient
-                    self.approximate_sq_grad(exp_avg_sq_row, exp_avg_sq_col, update)
+                    self.approximate_sq_grad(exp_avg_sq_row, exp_avg_sq_col, output=update)
                 else:
                     exp_avg_sq = state['exp_avg_sq']
                     exp_avg_sq.mul_(beta2_t).add_(update, alpha=1.0 - beta2_t)
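The call-site change only makes the in-place output explicit (`output=update`): the rank-1 inverse-square-root approximation is written directly into `update`. A hedged end-to-end sketch of this factored branch (not the optimizer itself; `beta2_t`, the epsilon, and the tensor shape are illustrative assumptions):

import torch

beta2_t, eps = 0.999, 1e-30
grad = torch.randn(128, 64)
update = grad ** 2 + eps                          # squared gradient plus epsilon

exp_avg_sq_row = torch.zeros(grad.shape[:-1])
exp_avg_sq_col = torch.zeros(grad.shape[:-2] + grad.shape[-1:])

# EMA of the row/column means of the squared gradient, as in the hunk above.
exp_avg_sq_row.mul_(beta2_t).add_(update.mean(dim=-1), alpha=1.0 - beta2_t)
exp_avg_sq_col.mul_(beta2_t).add_(update.mean(dim=-2), alpha=1.0 - beta2_t)

# The factored inverse-sqrt approximation is written in place into `update`
# (hence the explicit output=update keyword at the call site).
r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
torch.mul(r_factor, c_factor, out=update)
update.mul_(grad)                                 # preconditioned update direction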
