
Commit 3a28bae

Merge pull request #362 from kozistr/update/scion-optimizer
[Update] SCION optimizer
2 parents f1074de + b40db48 commit 3a28bae

10 files changed: +68 / -16 lines


docs/changelogs/v3.4.3.md

Lines changed: 3 additions & 0 deletions
```diff
@@ -13,6 +13,9 @@
 * adjust default hyperparameters the same as the original implementation.
 * support adjusted lr from the Moonlight. you can use it by setting `use_adjusted_lr=True`.
 * Tune the performance of the coupled Newton iteration method by 5% increase. (#360)
+* Update `SCION` optimizer. (#361)
+    * add `scale` parameter.
+    * update `get_lmo_direction`.
 
 ### Fix
```
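For reference, a minimal sketch of the two `SCION` changes listed above. The toy layer and the hyperparameter values are illustrative, not taken from this PR, and `SCION` is assumed to be importable from the package root like the other optimizers:

```python
import torch

from pytorch_optimizer import SCION

# `scale` is the new constructor argument (default 1.0); 50.0 is the value the
# updated docstring suggests for Transformer blocks.
optimizer = SCION(torch.nn.Linear(8, 8).parameters(), lr=1e-3, scale=50.0)

# `get_lmo_direction` is the static method whose behaviour changed: the
# 'spectral' branch now flattens the gradient to 2-D before Newton-Schulz and
# rescales the result by max(1, sqrt(d_out / d_in)).
direction = SCION.get_lmo_direction(torch.randn(8, 8), 'spectral')
```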

docs/visualization.md

Lines changed: 8 additions & 0 deletions
```diff
@@ -350,6 +350,10 @@
 
 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_StableAdamW.png)
 
+### StableSPAM
+
+![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_StableSPAM.png)
+
 ### SWATS
 
 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rastrigin_SWATS.png)
@@ -716,6 +720,10 @@
 
 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_StableAdamW.png)
 
+### StableSPAM
+
+![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_StableSPAM.png)
+
 ### SWATS
 
 ![image](https://raw.githubusercontent.com/kozistr/pytorch_optimizer/main/docs/visualizations/rosenbrock_SWATS.png)
```
4 binary files changed (image previews not shown): 260 Bytes, 634 KB, 9.12 KB, 142 KB

examples/visualize_optimizers.py

Lines changed: 6 additions & 6 deletions
```diff
@@ -204,7 +204,6 @@ def closure() -> float:
     parameters = list(model.parameters())
     optimizer_name: str = optimizer_class.__name__.lower()
 
-    # Special handling for optimizers with unique requirements
     if optimizer_name == 'ranger21':
         optimizer_config['num_iterations'] = num_iters
     elif optimizer_name == 'ranger25':
@@ -215,11 +214,12 @@ def closure() -> float:
         optimizer_config['projection_fn'] = lambda: l2_projection(parameters, max_norm=1)
     elif optimizer_name == 'bsam':
         optimizer_config['num_data'] = 1
+    elif optimizer_name == 'scion':
+        optimizer_config['scale'] = 50.0
 
-    if optimizer_name in OPTIMIZERS_MODEL_INPUT_NEEDED:
-        optimizer = optimizer_class(model, **optimizer_config)
-    else:
-        optimizer = optimizer_class(parameters, **optimizer_config)
+    optimizer = optimizer_class(
+        model if optimizer_name in OPTIMIZERS_MODEL_INPUT_NEEDED else parameters, **optimizer_config
+    )
 
     steps = torch.zeros((2, num_iters + 1), dtype=torch.float32)
     steps[:, 0] = model.x.detach()
@@ -394,7 +394,7 @@ def execute_experiments(
                 rstate=np.random.default_rng(seed),
             )
         except AllTrialsFailed:
-            print(f'⚠️ {optimizer_name} failed to optimize {func.__name__}')  # noqa: T201
+            print(f'{optimizer_name} failed to optimize {func.__name__}')  # noqa: T201
             continue
 
         steps, _ = execute_steps(func, initial_state, optimizer_class, best_params.copy(), TESTING_OPTIMIZATION_STEPS)
```
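The new `'scion'` branch above pins `scale=50.0` for the visualization runs. A hedged sketch of what that does in isolation, based on the `scion.py` changes further down (the LMO direction is multiplied by `scale` before the lr-scaled parameter update); the toy parameter and learning rate here are illustrative:

```python
import torch

from pytorch_optimizer import SCION

# 2-D toy parameter, similar in spirit to the 2-D test functions the script visualizes.
w = torch.nn.Parameter(torch.tensor([1.5, 1.5]))
optimizer = SCION([w], lr=1e-3, scale=50.0, lmo_type='sign')

loss = (w ** 2).sum()
loss.backward()
optimizer.step()  # the LMO direction is multiplied by scale=50.0, then applied with lr
optimizer.zero_grad()
```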

pytorch_optimizer/optimizer/scion.py

Lines changed: 24 additions & 8 deletions
```diff
@@ -1,3 +1,4 @@
+import math
 from typing import Literal
 
 import torch
@@ -18,30 +19,35 @@ class SCION(BaseOptimizer):
     :param momentum: float. momentum factor.
     :param constraint: bool. whether to use a constraint SCG or not.
     :param lmo_type: LMO_TYPE. supported LMO types.
+    :param scale: float. following the intended usage in the original implementation, 50.0 is used for Transformer
+        blocks, and 3000.0 is used for others (e.g. Embedding, LM head).
     :param weight_decay: float. weight decay (L2 penalty).
     :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
     """
 
     def __init__(
         self,
         params: PARAMETERS,
-        lr: float = 1e-4,
+        lr: float = 1e-3,
         momentum: float = 0.1,
         constraint: bool = False,
         lmo_type: LMO_TYPE = 'spectral',
+        scale: float = 1.0,
         weight_decay: float = 0.0,
         weight_decouple: bool = True,
         **kwargs,
     ):
         self.validate_learning_rate(lr)
         self.validate_range(momentum, 'momentum', 0.0, 1.0, '(]')
+        self.validate_positive(scale, 'scale')
         self.validate_options(lmo_type, 'lmo_type', ['spectral', 'sign', 'col_norm', 'row_norm'])
 
         defaults: DEFAULTS = {
             'lr': lr,
             'momentum': momentum,
             'constraint': constraint,
             'lmo_type': lmo_type,
+            'scale': scale,
             'weight_decay': weight_decay,
             'weight_decouple': weight_decouple,
         }
@@ -58,17 +64,26 @@ def reset(self):
                 state['d'] = torch.zeros_like(p)
 
     @staticmethod
-    def get_lmo_direction(grad: torch.Tensor, lmo_type: str) -> torch.Tensor:
-        r"""Get LMO direction."""
-        if lmo_type == 'spectral' and grad.ndim == 2:
-            return zero_power_via_newton_schulz_5(grad)
+    def get_lmo_direction(grad: torch.Tensor, lmo_type: LMO_TYPE) -> torch.Tensor:
+        r"""Get LMO direction.
+
+        fallback to `sign`.
+        """
+        d_out, d_in, *_ = grad.shape if grad.ndim > 1 else (grad.size(0), grad.size(0))
+
+        if lmo_type == 'spectral':
+            return (
+                zero_power_via_newton_schulz_5(grad.reshape(len(grad), -1))
+                .view(grad.shape)
+                .mul_(max(1.0, math.sqrt(d_out / d_in)))
+            )
         if lmo_type == 'sign':
-            return torch.sign(grad)
+            return torch.sign(grad).div_(d_in)
         if lmo_type == 'col_norm':
             return grad / torch.norm(grad, dim=0, keepdim=True).add_(1e-6)
         if lmo_type == 'row_norm' and grad.ndim == 2:
             return grad / torch.norm(grad, dim=1, keepdim=True).add_(1e-6)
-        return torch.sign(grad)
+        return torch.sign(grad).div_(d_in)
 
     @torch.no_grad()
     def step(self, closure: CLOSURE = None) -> LOSS:
@@ -89,12 +104,13 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
                 state = self.state[p]
                 if 'd' not in state:
-                    state['d'] = torch.zeros_like(p)
+                    state['d'] = torch.zeros_like(grad)
 
                 d = state['d']
                 d.mul_(1.0 - group['momentum']).add_(grad, alpha=group['momentum'])
 
                 update = self.get_lmo_direction(d, group['lmo_type'])
+                update.mul_(group['scale'])
 
                 if not group['constraint']:
                     self.apply_weight_decay(
```
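Because `scale` is stored in the parameter-group defaults, it can also be set per group. A hedged sketch following the updated docstring's suggestion; the module split below is illustrative and not part of this PR:

```python
import torch

from pytorch_optimizer import SCION

# Illustrative split: Transformer-block weights vs. embedding / LM-head weights.
embedding = torch.nn.Embedding(1000, 64)
block = torch.nn.Linear(64, 64)
lm_head = torch.nn.Linear(64, 1000)

optimizer = SCION(
    [
        {'params': block.parameters(), 'scale': 50.0},
        {'params': list(embedding.parameters()) + list(lm_head.parameters()), 'scale': 3000.0},
    ],
    lr=1e-3,
    lmo_type='spectral',
)
```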

tests/constants.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -565,8 +565,8 @@
     (FOCUS, {'lr': 1e-1, 'weight_decay': 1e-3}, 5),
     (Kron, {'lr': 1e0, 'weight_decay': 1e-3}, 3),
     (EXAdam, {'lr': 1e-1, 'weight_decay': 1e-3}, 5),
-    (SCION, {'lr': 5e-1, 'constraint': False, 'weight_decay': 1e-3}, 10),
-    (SCION, {'lr': 1e-1, 'constraint': True}, 10),
+    (SCION, {'lr': 5e-1, 'constraint': False, 'weight_decay': 1e-3}, 5),
+    (SCION, {'lr': 1e-1, 'constraint': True, 'lmo_type': 'col_norm'}, 10),
     (Ranger25, {'lr': 1e-1}, 3),
     (Ranger25, {'lr': 1e-1, 't_alpha_beta3': 5}, 3),
     (Ranger25, {'lr': 5e-2, 'stable_adamw': False, 'orthograd': False, 'eps': None, 'lookahead_merge_time': 2}, 3),
```

tests/test_optimizers.py

Lines changed: 25 additions & 0 deletions
```diff
@@ -981,6 +981,31 @@ def test_kron_optimizer():
     optimizer.step()
 
 
+def test_scion_lmo_types():
+    grad = torch.ones(2, 2)
+
+    expected = torch.FloatTensor([[0.3438, 0.3438], [0.3438, 0.3438]]).bfloat16()
+    actual = load_optimizer('scion').get_lmo_direction(grad, 'spectral')
+
+    torch.testing.assert_close(expected, actual, rtol=1e-5, atol=1e-5)
+
+    expected = torch.FloatTensor([[0.5, 0.5], [0.5, 0.5]])
+    actual = load_optimizer('scion').get_lmo_direction(grad, 'sign')
+    torch.testing.assert_close(actual, expected, rtol=1e-5, atol=1e-5)
+
+    expected = torch.FloatTensor([[0.7071, 0.7071], [0.7071, 0.7071]])
+    actual = load_optimizer('scion').get_lmo_direction(grad, 'row_norm')
+    torch.testing.assert_close(actual, expected, rtol=1e-5, atol=1e-5)
+
+    expected = torch.FloatTensor([[0.7071, 0.7071], [0.7071, 0.7071]])
+    actual = load_optimizer('scion').get_lmo_direction(grad, 'col_norm')
+    torch.testing.assert_close(actual, expected, rtol=1e-5, atol=1e-5)
+
+    expected = torch.FloatTensor([[0.5, 0.5], [0.5, 0.5]])
+    actual = load_optimizer('scion').get_lmo_direction(grad, 'asdf')
+    torch.testing.assert_close(actual, expected, rtol=1e-5, atol=1e-5)
+
+
 def test_schedulefree_wrapper():
     model = Example()
 
```
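As a sanity check on the expected values in the new test, for `grad = torch.ones(2, 2)` (a standalone sketch, not part of the test suite):

```python
import torch

grad = torch.ones(2, 2)
d_in = grad.shape[1]  # 2

# 'sign' (and the unknown-type fallback): sign(grad) / d_in -> 1 / 2 = 0.5 everywhere
print(torch.sign(grad) / d_in)

# 'row_norm' / 'col_norm': every row (column) of ones(2, 2) has L2 norm sqrt(2),
# so each entry becomes 1 / sqrt(2) ~= 0.7071
print(grad / (grad.norm(dim=1, keepdim=True) + 1e-6))
```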
