
Commit 487200d

Merge pull request #365 from kozistr/fix/spam-optimizer
[Fix] potential bug in SPAM optimizer
2 parents 3a28bae + c7b2b0a commit 487200d

14 files changed (+35, -65 lines)

docs/changelogs/v3.4.3.md

Lines changed: 1 addition & 0 deletions
@@ -20,3 +20,4 @@
 ### Fix
 
 * bias_correction2 in ScheduleFreeRAdam optimizer. (#354)
+* potential bug in SPAM optimizer. (#365)

pytorch_optimizer/optimizer/experimental/ranger25.py

Lines changed: 3 additions & 3 deletions
@@ -169,9 +169,9 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 state = self.state[p]

                 if len(state) == 0:
-                    state['exp_avg'] = torch.zeros_like(p)
-                    state['exp_avg_sq'] = torch.zeros_like(p)
-                    state['exp_avg_slow'] = torch.zeros_like(p)
+                    state['exp_avg'] = torch.zeros_like(grad)
+                    state['exp_avg_sq'] = torch.zeros_like(grad)
+                    state['exp_avg_slow'] = torch.zeros_like(grad)
                     state['slow_momentum'] = p.clone()

                 self.apply_weight_decay(
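This file and several below switch the lazily created optimizer state from torch.zeros_like(p) to torch.zeros_like(grad). A minimal sketch of why sizing the buffers from the gradient matters, using toy tensors rather than the library's own code (all names here are illustrative): once the gradient has been masked, as SPAM's step() does further down, a buffer shaped like the full parameter no longer lines up with it, while one shaped like the gradient always does.

import torch

p = torch.nn.Parameter(torch.randn(4, 4))
grad = torch.randn(4, 4)                 # stand-in for p.grad

mask = torch.rand_like(p) < 0.25         # sparse update mask, like the one SPAM keeps per parameter
grad = grad[mask]                        # boolean indexing -> flat 1-D tensor

exp_avg = torch.zeros_like(grad)         # matches the (masked) gradient exactly
exp_avg.mul_(0.9).add_(grad, alpha=0.1)  # elementwise EMA update works

# a buffer built with torch.zeros_like(p) would keep the 4x4 shape and no
# longer align element-for-element with the flat masked gradient above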

pytorch_optimizer/optimizer/sgd.py

Lines changed: 1 addition & 1 deletion
@@ -389,7 +389,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 state = self.state[p]
                 if momentum > 0.0:
                     if len(state) == 0:
-                        state['momentum_buffer'] = torch.zeros_like(p)
+                        state['momentum_buffer'] = torch.zeros_like(grad)

                     buf = state['momentum_buffer']
                     buf.mul_(momentum).add_(grad, alpha=1.0 - momentum)

pytorch_optimizer/optimizer/sgdp.py

Lines changed: 2 additions & 6 deletions
@@ -62,11 +62,7 @@ def __str__(self) -> str:

     @torch.no_grad()
     def reset(self):
-        for group in self.param_groups:
-            for p in group['params']:
-                state = self.state[p]
-
-                state['momentum'] = torch.zeros_like(p)
+        pass

     @torch.no_grad()
     def step(self, closure: CLOSURE = None) -> LOSS:
@@ -87,7 +83,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:

                 state = self.state[p]
                 if len(state) == 0:
-                    state['momentum'] = torch.zeros_like(p)
+                    state['momentum'] = torch.zeros_like(grad)

                 buf = state['momentum']
                 buf.mul_(momentum).add_(grad, alpha=1.0 - group['dampening'])

pytorch_optimizer/optimizer/shampoo.py

Lines changed: 1 addition & 1 deletion
@@ -303,7 +303,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:

                 state = self.state[p]
                 if len(state) == 0:
-                    state['momentum'] = torch.zeros_like(p)
+                    state['momentum'] = torch.zeros_like(grad)
                     state['pre_conditioner'] = PreConditioner(
                         p,
                         beta2,

pytorch_optimizer/optimizer/sm3.py

Lines changed: 2 additions & 2 deletions
@@ -91,12 +91,12 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 state = self.state[p]
                 if len(state) == 0:
                     state['step'] = 0
-                    state['momentum_buffer'] = torch.zeros_like(p)
+                    state['momentum_buffer'] = torch.zeros_like(grad)

                     if grad.is_sparse:
                         state['accumulator_0'] = torch.zeros(shape[0], dtype=grad.dtype, device=grad.device)
                     elif rank == 0:
-                        state['accumulator_0'] = torch.zeros_like(p)
+                        state['accumulator_0'] = torch.zeros_like(grad)
                     else:
                         for i in range(rank):
                             state[f'accumulator_{i}'] = torch.zeros(

pytorch_optimizer/optimizer/soap.py

Lines changed: 1 addition & 1 deletion
@@ -161,7 +161,7 @@ def get_orthogonal_matrix_qr(self, state, max_precondition_dim: int = 10000, mer
         # Compute QR decomposition
         # We cast to float32 because:
         # - torch.linalg.qr does not have support for types like bfloat16 as of PyTorch 2.5.1
-        # - the correctness / numerical stability of the Q orthogonalization is important for the stability
+        # - the correctness / numerical stability of the Q orthogonality is important for the stability
         #   of the optimizer
         q, _ = torch.linalg.qr(power_iter.to(torch.float32))
         q = q.to(power_iter.dtype)

pytorch_optimizer/optimizer/sophia.py

Lines changed: 2 additions & 2 deletions
@@ -113,8 +113,8 @@ def step(self, closure: CLOSURE = None, hessian: Optional[List[torch.Tensor]] =

                 state = self.state[p]
                 if len(state) == 0:
-                    state['momentum'] = torch.zeros_like(p)
-                    state['hessian_moment'] = torch.zeros_like(p)
+                    state['momentum'] = torch.zeros_like(grad)
+                    state['hessian_moment'] = torch.zeros_like(grad)

                 self.apply_weight_decay(
                     p=p,

pytorch_optimizer/optimizer/spam.py

Lines changed: 9 additions & 19 deletions
@@ -68,7 +68,7 @@ def __init__(
         betas: BETAS = (0.9, 0.999),
         density: float = 1.0,
         weight_decay: float = 0.0,
-        warmup_epoch: int = 150,
+        warmup_epoch: int = 50,
         threshold: int = 5000,
         grad_accu_steps: int = 20,
         update_proj_gap: int = 500,
@@ -90,11 +90,12 @@ def __init__(
         self.threshold = threshold
         self.grad_accu_steps = grad_accu_steps
         self.update_proj_gap = update_proj_gap
-        self.warmup = CosineDecay(0.99, warmup_epoch)

         defaults: DEFAULTS = {'lr': lr, 'betas': betas, 'weight_decay': weight_decay, 'eps': eps, **kwargs}
         super().__init__(params, defaults)

+        self.warmup = CosineDecay(0.99, self.warmup_epoch)
+
         self.init_masks()

         self.state['total_step'] = 0
@@ -119,17 +120,16 @@ def initialize_random_rank_boolean_tensor(m: int, n: int, density: float, device

         return tensor.view(m, n)

-    def update_mask_random(self, density: float, p: torch.Tensor, old_mask: torch.Tensor) -> torch.Tensor:
+    def update_mask_random(self, p: torch.Tensor, old_mask: torch.Tensor) -> torch.Tensor:
         r"""Update a random mask.

         Create a new random mask with the same density, compute overlap ratio with old_mask, and update the EMA for
         the overlap region.

-        :param density: float. fraction of elements to keep.
         :param p: torch.Tensor. parameter to which the mask is applied.
         :param old_mask: torch.Tensor. previous binary mask.
         """
-        new_mask: torch.Tensor = torch.rand_like(p) < density
+        new_mask: torch.Tensor = torch.rand_like(p) < self.density

         exp_avg = torch.zeros_like(p[new_mask])
         exp_avg_sq = torch.zeros_like(p[new_mask])
@@ -155,8 +155,8 @@ def update_masks(self) -> None:
         for group in self.param_groups:
             for p in group['params']:
                 state = self.state[p]
-                if 'mask' in state:
-                    state['mask'] = self.update_mask_random(self.density, p, state['mask'])
+                if p.dim() == 2 and 'mask' in state:
+                    state['mask'] = self.update_mask_random(p, state['mask'])
                     p.mask = state['mask']

     def init_masks(self) -> None:
@@ -177,13 +177,7 @@ def __str__(self) -> str:

     @torch.no_grad()
     def reset(self):
-        for group in self.param_groups:
-            group['step'] = 0
-            for p in group['params']:
-                state = self.state[p]
-
-                state['exp_avg'] = torch.zeros_like(p)
-                state['exp_avg_sq'] = torch.zeros_like(p)
+        pass

     @torch.no_grad()
     def step(self, closure: CLOSURE = None) -> LOSS:
@@ -220,11 +214,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 if 'mask' in state:
                     grad = grad[state['mask']]

-                if 'exp_avg' not in state:
-                    state['exp_avg'] = torch.zeros_like(grad)
-                    state['exp_avg_sq'] = torch.zeros_like(grad)
-
-                if (self.state['total_step'] + 1) % self.update_proj_gap == 0:
+                if ('exp_avg' not in state) or (self.state['total_step'] + 1) % self.update_proj_gap == 0:
                     state['exp_avg'] = torch.zeros_like(grad)
                     state['exp_avg_sq'] = torch.zeros_like(grad)
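The last hunk above collapses the two earlier checks into a single condition: the moments are (re)allocated when they do not yet exist or whenever total_step crosses an update_proj_gap boundary. A rough sketch of that control flow, with plain local variables standing in for the optimizer's own state (hypothetical names, not the library's code):

import torch

update_proj_gap = 500
total_step = 999                         # (total_step + 1) % update_proj_gap == 0
state: dict = {}

grad = torch.randn(10)                   # already-masked gradient, 1-D

# single combined condition from the fix: lazy init on first use,
# plus a reset whenever the projection/mask refresh interval is hit
if ('exp_avg' not in state) or (total_step + 1) % update_proj_gap == 0:
    state['exp_avg'] = torch.zeros_like(grad)
    state['exp_avg_sq'] = torch.zeros_like(grad)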

pytorch_optimizer/optimizer/srmm.py

Lines changed: 2 additions & 2 deletions
@@ -72,8 +72,8 @@ def step(self, closure: CLOSURE = None) -> LOSS:

                 state = self.state[p]
                 if len(state) == 0:
-                    state['mov_avg_grad'] = torch.zeros_like(p)
-                    state['mov_avg_param'] = torch.zeros_like(p)
+                    state['mov_avg_grad'] = torch.zeros_like(grad)
+                    state['mov_avg_param'] = torch.zeros_like(grad)

                 mov_avg_grad, mov_avg_param = state['mov_avg_grad'], state['mov_avg_param']
