Skip to content

Commit 27db169

Browse files
committed
update: p to grad
1 parent 64f3412 commit 27db169

File tree

11 files changed

+22
-43
lines changed

11 files changed

+22
-43
lines changed

pytorch_optimizer/optimizer/sgd.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -389,7 +389,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
389389
state = self.state[p]
390390
if momentum > 0.0:
391391
if len(state) == 0:
392-
state['momentum_buffer'] = torch.zeros_like(p)
392+
state['momentum_buffer'] = torch.zeros_like(grad)
393393

394394
buf = state['momentum_buffer']
395395
buf.mul_(momentum).add_(grad, alpha=1.0 - momentum)

pytorch_optimizer/optimizer/sgdp.py

Lines changed: 2 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -62,11 +62,7 @@ def __str__(self) -> str:
6262

6363
@torch.no_grad()
6464
def reset(self):
65-
for group in self.param_groups:
66-
for p in group['params']:
67-
state = self.state[p]
68-
69-
state['momentum'] = torch.zeros_like(p)
65+
pass
7066

7167
@torch.no_grad()
7268
def step(self, closure: CLOSURE = None) -> LOSS:
@@ -87,7 +83,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
8783

8884
state = self.state[p]
8985
if len(state) == 0:
90-
state['momentum'] = torch.zeros_like(p)
86+
state['momentum'] = torch.zeros_like(grad)
9187

9288
buf = state['momentum']
9389
buf.mul_(momentum).add_(grad, alpha=1.0 - group['dampening'])

pytorch_optimizer/optimizer/shampoo.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -303,7 +303,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
303303

304304
state = self.state[p]
305305
if len(state) == 0:
306-
state['momentum'] = torch.zeros_like(p)
306+
state['momentum'] = torch.zeros_like(grad)
307307
state['pre_conditioner'] = PreConditioner(
308308
p,
309309
beta2,

pytorch_optimizer/optimizer/sm3.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -91,12 +91,12 @@ def step(self, closure: CLOSURE = None) -> LOSS:
9191
state = self.state[p]
9292
if len(state) == 0:
9393
state['step'] = 0
94-
state['momentum_buffer'] = torch.zeros_like(p)
94+
state['momentum_buffer'] = torch.zeros_like(grad)
9595

9696
if grad.is_sparse:
9797
state['accumulator_0'] = torch.zeros(shape[0], dtype=grad.dtype, device=grad.device)
9898
elif rank == 0:
99-
state['accumulator_0'] = torch.zeros_like(p)
99+
state['accumulator_0'] = torch.zeros_like(grad)
100100
else:
101101
for i in range(rank):
102102
state[f'accumulator_{i}'] = torch.zeros(

pytorch_optimizer/optimizer/soap.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -161,7 +161,7 @@ def get_orthogonal_matrix_qr(self, state, max_precondition_dim: int = 10000, mer
161161
# Compute QR decomposition
162162
# We cast to float32 because:
163163
# - torch.linalg.qr does not have support for types like bfloat16 as of PyTorch 2.5.1
164-
# - the correctness / numerical stability of the Q orthogonalization is important for the stability
164+
# - the correctness / numerical stability of the Q orthogonality is important for the stability
165165
# of the optimizer
166166
q, _ = torch.linalg.qr(power_iter.to(torch.float32))
167167
q = q.to(power_iter.dtype)

pytorch_optimizer/optimizer/sophia.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -113,8 +113,8 @@ def step(self, closure: CLOSURE = None, hessian: Optional[List[torch.Tensor]] =
113113

114114
state = self.state[p]
115115
if len(state) == 0:
116-
state['momentum'] = torch.zeros_like(p)
117-
state['hessian_moment'] = torch.zeros_like(p)
116+
state['momentum'] = torch.zeros_like(grad)
117+
state['hessian_moment'] = torch.zeros_like(grad)
118118

119119
self.apply_weight_decay(
120120
p=p,

pytorch_optimizer/optimizer/srmm.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -72,8 +72,8 @@ def step(self, closure: CLOSURE = None) -> LOSS:
7272

7373
state = self.state[p]
7474
if len(state) == 0:
75-
state['mov_avg_grad'] = torch.zeros_like(p)
76-
state['mov_avg_param'] = torch.zeros_like(p)
75+
state['mov_avg_grad'] = torch.zeros_like(grad)
76+
state['mov_avg_param'] = torch.zeros_like(grad)
7777

7878
mov_avg_grad, mov_avg_param = state['mov_avg_grad'], state['mov_avg_param']
7979

pytorch_optimizer/optimizer/swats.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -110,17 +110,17 @@ def step(self, closure: CLOSURE = None) -> LOSS:
110110
state = self.state[p]
111111

112112
if len(state) == 0:
113-
state['exp_avg'] = torch.zeros_like(p)
114-
state['exp_avg_sq'] = torch.zeros_like(p)
113+
state['exp_avg'] = torch.zeros_like(grad)
114+
state['exp_avg_sq'] = torch.zeros_like(grad)
115115
state['exp_avg2'] = torch.zeros((1,), dtype=grad.dtype, device=grad.device)
116116
if group['ams_bound']:
117-
state['max_exp_avg_sq'] = torch.zeros_like(p)
117+
state['max_exp_avg_sq'] = torch.zeros_like(grad)
118118
if group['adanorm']:
119119
state['exp_grad_norm'] = torch.zeros((1,), dtype=grad.dtype, device=grad.device)
120120

121121
self.apply_weight_decay(
122122
p=p,
123-
grad=p.grad,
123+
grad=grad,
124124
lr=group['lr'],
125125
weight_decay=group['weight_decay'],
126126
weight_decouple=group['weight_decouple'],

pytorch_optimizer/optimizer/tam.py

Lines changed: 2 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -54,13 +54,7 @@ def __str__(self) -> str:
5454

5555
@torch.no_grad()
5656
def reset(self):
57-
for group in self.param_groups:
58-
group['step'] = 0
59-
for p in group['params']:
60-
state = self.state[p]
61-
62-
state['s'] = torch.zeros_like(p)
63-
state['momentum_buffer'] = torch.zeros_like(p)
57+
pass
6458

6559
@torch.no_grad()
6660
def step(self, closure: CLOSURE = None) -> LOSS:
@@ -157,14 +151,7 @@ def __str__(self) -> str:
157151

158152
@torch.no_grad()
159153
def reset(self):
160-
for group in self.param_groups:
161-
group['step'] = 0
162-
for p in group['params']:
163-
state = self.state[p]
164-
165-
state['s'] = torch.zeros_like(p)
166-
state['exp_avg'] = torch.zeros_like(p)
167-
state['exp_avg_sq'] = torch.zeros_like(p)
154+
pass
168155

169156
@torch.no_grad()
170157
def step(self, closure: CLOSURE = None) -> LOSS:

pytorch_optimizer/optimizer/tiger.py

Lines changed: 2 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -45,11 +45,7 @@ def __str__(self) -> str:
4545

4646
@torch.no_grad()
4747
def reset(self):
48-
for group in self.param_groups:
49-
for p in group['params']:
50-
state = self.state[p]
51-
52-
state['exp_avg'] = torch.zeros_like(p)
48+
pass
5349

5450
@torch.no_grad()
5551
def step(self, closure: CLOSURE = None) -> LOSS:
@@ -71,7 +67,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
7167
state = self.state[p]
7268

7369
if len(state) == 0:
74-
state['exp_avg'] = torch.zeros_like(p)
70+
state['exp_avg'] = torch.zeros_like(grad)
7571

7672
self.apply_weight_decay(
7773
p=p,

0 commit comments

Comments (0)