
Commit f1074de

Merge pull request #360 from kozistr/update/codes
[Update] lots of things
2 parents c950609 + 34584e8 commit f1074de

13 files changed (+314, -46 lines)


docs/changelogs/v3.4.3.md

Lines changed: 3 additions & 1 deletion
@@ -4,13 +4,15 @@
 
 * Support `StableSPAM` optimizer. (#358, #359)
     * [How to Train in 4-Bit More Stably than 16-Bit Adam](https://arxiv.org/abs/2502.17055?)
+* Support `ScheduleFreeWrapper`. (#334, #360)
 
 ### Update
 
 * Update Muon optimizer. (#355, #356)
     * support decoupled weight decay.
-    * adjust default hyperparameters same with the original implementation.
+    * adjust default hyperparameters the same as the original implementation.
     * support adjusted lr from the Moonlight. you can use it by setting `use_adjusted_lr=True`.
+* Tune the performance of the coupled Newton iteration method (~5% faster). (#360)
 
 ### Fix
 
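The `use_adjusted_lr` entry above refers to the Moonlight-style learning-rate adjustment for Muon. A minimal, hedged sketch of passing the flag; only the `use_adjusted_lr=True` keyword comes from the changelog, while the toy model and the rest of the call are illustrative assumptions:

import torch
from pytorch_optimizer import Muon

# Hedged sketch: only `use_adjusted_lr=True` is confirmed by the changelog entry;
# the toy model and the remaining defaults are assumptions for illustration.
model = torch.nn.Linear(128, 64)
optimizer = Muon(model.parameters(), use_adjusted_lr=True)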

docs/optimizer.md

Lines changed: 4 additions & 0 deletions
@@ -336,6 +336,10 @@
     :docstring:
     :members:
 
+::: pytorch_optimizer.ScheduleFreeWrapper
+    :docstring:
+    :members:
+
 ::: pytorch_optimizer.SCION
     :docstring:
     :members:

poetry.lock

Lines changed: 25 additions & 25 deletions
Some generated files are not rendered by default.

pytorch_optimizer/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -139,6 +139,7 @@
     ScheduleFreeAdamW,
     ScheduleFreeRAdam,
     ScheduleFreeSGD,
+    ScheduleFreeWrapper,
     SGDSaI,
     Shampoo,
     SignSGD,
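Together with the matching change in `pytorch_optimizer/optimizer/__init__.py` below, this export makes the wrapper importable from the package root:

# One-line check of the new top-level export added by this commit.
from pytorch_optimizer import ScheduleFreeWrapper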

pytorch_optimizer/optimizer/__init__.py

Lines changed: 6 additions & 1 deletion
@@ -80,7 +80,12 @@
 from pytorch_optimizer.optimizer.ranger21 import Ranger21
 from pytorch_optimizer.optimizer.rotograd import RotoGrad
 from pytorch_optimizer.optimizer.sam import BSAM, GSAM, SAM, WSAM, LookSAM
-from pytorch_optimizer.optimizer.schedulefree import ScheduleFreeAdamW, ScheduleFreeRAdam, ScheduleFreeSGD
+from pytorch_optimizer.optimizer.schedulefree import (
+    ScheduleFreeAdamW,
+    ScheduleFreeRAdam,
+    ScheduleFreeSGD,
+    ScheduleFreeWrapper,
+)
 from pytorch_optimizer.optimizer.scion import SCION
 from pytorch_optimizer.optimizer.sgd import ASGD, SGDW, AccSGD, SGDSaI, SignSGD
 from pytorch_optimizer.optimizer.sgdp import SGDP

pytorch_optimizer/optimizer/schedulefree.py

Lines changed: 198 additions & 2 deletions
@@ -1,10 +1,12 @@
-from typing import List
+from collections import defaultdict
+from typing import Callable, Dict, List
 
 import torch
+from torch.optim import Optimizer
 
 from pytorch_optimizer.base.exception import NoSparseGradientError
 from pytorch_optimizer.base.optimizer import BaseOptimizer
-from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
+from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, LOSS, OPTIMIZER_INSTANCE_OR_CLASS, PARAMETERS, STATE
 
 
 class ScheduleFreeSGD(BaseOptimizer):
@@ -454,3 +456,197 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 z.sub_(grad, alpha=lr)
 
         return loss
+
+
+class ScheduleFreeWrapper(BaseOptimizer):
+    r"""Wrap any optimizer to make it Schedule-Free.
+
+    This version uses a memory-efficient swap operation but may be slower than the reference version. In most cases
+    the performance difference is negligible. For the best possible performance and memory usage, Schedule-Free
+    needs to be directly integrated with the base optimizer.
+
+    When using this version, you can disable the base optimizer's momentum, as it's no longer necessary when using
+    the wrapper's momentum (although you can use both types of momentum if you want).
+
+    If you set weight decay on the base optimizer, it computes weight decay at $z$. This wrapper offers the option
+    to compute weight decay at $y$ instead, via its `weight_decay` parameter, which seems to give better results in
+    our experiments. This approach to decay only works correctly if the base optimizer uses group['lr'] as the
+    current learning rate.
+
+    :param optimizer: OPTIMIZER_INSTANCE_OR_CLASS. base optimizer.
+    :param momentum: float. momentum.
+    :param weight_decay: float. weight decay (L2 penalty).
+    :param r: float. use polynomial weighting in the average with power r.
+    :param weight_lr_power: float. during warmup, the weights in the average will be equal to lr raised to this
+        power. set to 0 for no weighting.
+    """
+
+    def __init__(
+        self,
+        optimizer: OPTIMIZER_INSTANCE_OR_CLASS,
+        momentum: float = 0.9,
+        weight_decay: float = 0.0,
+        r: float = 0.0,
+        weight_lr_power: float = 2.0,
+        **kwargs,
+    ):
+        self.validate_range(momentum, 'momentum', 0.0, 1.0, '[)')
+        self.validate_non_negative(weight_decay, 'weight_decay')
+
+        self.momentum = momentum
+        self.weight_decay = weight_decay
+        self.r = r
+        self.weight_lr_power = weight_lr_power
+        self.train_mode: bool = False
+
+        self.optimizer: Optimizer = self.load_optimizer(optimizer, **kwargs)
+
+        self._optimizer_step_pre_hooks: Dict[int, Callable] = {}
+        self._optimizer_step_post_hooks: Dict[int, Callable] = {}
+
+        self.state: STATE = defaultdict(dict)
+
+        for group in self.param_groups:
+            for p in group['params']:
+                state = self.state[p]
+                state['z'] = torch.clone(p)
+
+        self.defaults = self.optimizer.defaults
+
+    def __str__(self) -> str:
+        return 'ScheduleFree'
+
+    @property
+    def param_groups(self):
+        return self.optimizer.param_groups
+
+    def __getstate__(self):
+        return {'state': self.state, 'optimizer': self.optimizer}
+
+    def add_param_group(self, param_group):
+        return self.optimizer.add_param_group(param_group)
+
+    def state_dict(self) -> STATE:
+        return {'schedulefree_state': self.state, 'base_optimizer': self.optimizer.state_dict()}
+
+    def load_state_dict(self, state: STATE) -> None:
+        r"""Load state."""
+        self.state = state['schedulefree_state']
+        self.optimizer.load_state_dict(state['base_optimizer'])
+
+    def zero_grad(self, set_to_none: bool = True) -> None:
+        self.optimizer.zero_grad(set_to_none)
+
+    @torch.no_grad()
+    def eval(self):
+        if not self.train_mode:
+            return
+
+        for group in self.param_groups:
+            for p in group['params']:
+                state = self.state[p]
+                if 'z' in state:
+                    p.lerp_(end=state['z'], weight=1.0 - 1.0 / self.momentum)
+
+        self.train_mode = False
+
+    @torch.no_grad()
+    def train(self):
+        if self.train_mode:
+            return
+
+        for group in self.param_groups:
+            for p in group['params']:
+                state = self.state[p]
+                if 'z' in state:
+                    p.lerp_(end=state['z'], weight=1.0 - self.momentum)
+
+        self.train_mode = True
+
+    @torch.no_grad()
+    def reset(self):
+        pass
+
+    @staticmethod
+    def swap(x: torch.Tensor, y: torch.Tensor) -> None:
+        x.view(torch.uint8).bitwise_xor_(y.view(torch.uint8))
+        y.view(torch.uint8).bitwise_xor_(x.view(torch.uint8))
+        x.view(torch.uint8).bitwise_xor_(y.view(torch.uint8))
+
+    @torch.no_grad()
+    def step(self, closure: CLOSURE = None) -> LOSS:
+        if not self.train_mode:
+            raise ValueError('optimizer was not in train mode when step() was called. call .train() before training')
+
+        loss: LOSS = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+
+                grad = p.grad
+                if grad.is_sparse:
+                    raise NoSparseGradientError(str(self))
+
+                state = self.state[p]
+
+                z = state['z']
+
+                self.apply_weight_decay(
+                    z,
+                    grad,
+                    lr=group['lr'],
+                    weight_decay=self.weight_decay,
+                    weight_decouple=True,
+                    fixed_decay=False,
+                )
+
+                self.apply_weight_decay(
+                    p,
+                    grad,
+                    lr=group['lr'],
+                    weight_decay=self.weight_decay,
+                    weight_decouple=True,
+                    fixed_decay=False,
+                    ratio=1.0 - self.momentum,
+                )
+
+                p.lerp_(end=z, weight=1.0 - 1.0 / self.momentum)
+
+                self.swap(z, p)
+
+        self.optimizer.step()
+
+        for group in self.param_groups:
+            if 'step' in group:
+                group['step'] += 1
+            else:
+                group['step'] = 1
+
+            lr: float = group['lr'] * group.get('d', 1.0)
+            lr_max = group['lr_max'] = max(lr, group.get('lr_max', 0))
+
+            weight: float = (group['step'] ** self.r) * (lr_max ** self.weight_lr_power)  # fmt: skip
+            weight_sum = group['weight_sum'] = group.get('weight_sum', 0.0) + weight
+
+            checkpoint: float = weight / weight_sum if weight_sum != 0.0 else 0.0
+
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+
+                state = self.state[p]
+
+                z = state['z']
+
+                self.swap(z, p)
+
+                p.lerp_(end=z, weight=checkpoint)
+
+                p.lerp_(end=state['z'], weight=1.0 - self.momentum)
+
+        return loss
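A minimal training-loop sketch for the wrapper added above, assuming an already-constructed base optimizer can be passed (suggested by the `OPTIMIZER_INSTANCE_OR_CLASS` type hint); the toy model, the SGD base, and the hyperparameters are illustrative, while the `train()`/`step()`/`eval()` calls mirror the methods defined in this diff:

import torch
from pytorch_optimizer import ScheduleFreeWrapper

# Illustrative sketch, not from the commit: wrap a plain SGD base optimizer.
# Base momentum is left at its default of 0, since the docstring notes the
# wrapper's own momentum makes it unnecessary.
model = torch.nn.Linear(10, 1)
base_optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer = ScheduleFreeWrapper(base_optimizer, momentum=0.9, weight_decay=1e-2)

optimizer.train()  # step() raises ValueError unless the wrapper is in train mode
for _ in range(100):
    optimizer.zero_grad()
    loss = (model(torch.randn(32, 10)) - torch.randn(32, 1)).pow(2).mean()
    loss.backward()
    optimizer.step()

optimizer.eval()  # switch to the averaged weights before evaluation or checkpointing

The `swap` helper exchanges `p` and the `z` buffer in place with three XOR passes over their byte views; this is the memory-efficient swap the docstring refers to: no temporary tensor is allocated, at the cost of a few extra elementwise passes.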

pytorch_optimizer/optimizer/shampoo_utils.py

Lines changed: 12 additions & 9 deletions
@@ -412,7 +412,7 @@ def power_iteration(mat_g: torch.Tensor, num_iters: int = 100) -> torch.Tensor:
     return (v.t() @ mat_g @ v).clamp_min_(1e-16)
 
 
-@torch.no_grad()
+@torch.inference_mode()
 def compute_power_schur_newton(
     mat_g: torch.Tensor,
     p: int,
@@ -465,24 +465,27 @@ def compute_power_schur_newton(
     alpha: float = -1.0 / p
     alpha_identity = (1.0 - alpha) * identity
 
-    prev_error = torch.max(torch.abs(mat_m - identity))
+    prev_error = torch.dist(mat_m, identity, p=torch.inf)
+
+    mat_m_i = torch.empty_like(mat_m)
+    new_mat_root = torch.empty_like(mat_root)
 
     for _ in range(max_iters):
-        mat_m_i = alpha_identity + alpha * mat_m
+        torch.add(alpha_identity, alpha * mat_m, out=mat_m_i)
+        torch.matmul(mat_root, mat_m_i, out=new_mat_root)
 
-        new_mat_root = torch.matmul(mat_root, mat_m_i)
         torch.matmul(torch.linalg.matrix_power(mat_m_i, p), mat_m, out=mat_m)
 
-        error = torch.max(torch.abs(mat_m - identity))
+        error = torch.dist(mat_m, identity, p=torch.inf)
 
         # NOTE
-        # this is the main bottleneck that makes Scalable Shampoo slow.
-        # because it is handled on the Python side so values need to be on the CPU
-        # while XLA devices (e.g. TPU) doesn't seem to be affected.
+        # This is the main bottleneck that slows Scalable Shampoo: the convergence check runs on the
+        # Python side, so the error values have to be moved to the CPU, while XLA devices (e.g. TPU)
+        # don't seem to be affected.
        if torch.logical_or(error > prev_error * max_error_ratio, error <= error_tolerance):
             break
 
-        mat_root = new_mat_root
+        mat_root.copy_(new_mat_root)
         prev_error = error
 
     return mat_root
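A small, hedged illustration of the two micro-optimizations in the hunk above: `torch.dist(a, b, p=torch.inf)` returns the same max-absolute-difference as the old `torch.max(torch.abs(a - b))` without a user-visible temporary, and writing into preallocated buffers with `out=` / `copy_()` avoids re-allocating `mat_m_i` and `new_mat_root` on every Newton iteration. The tensor shapes here are arbitrary.

import torch

# Equivalence check: the inf-norm distance equals the max absolute elementwise difference.
a, b = torch.randn(64, 64), torch.randn(64, 64)
assert torch.allclose(torch.max(torch.abs(a - b)), torch.dist(a, b, p=torch.inf))

# Buffer reuse: torch.matmul writes into the same preallocated tensor on each pass,
# mirroring how the loop above reuses mat_m_i and new_mat_root.
out = torch.empty(64, 64)
for _ in range(3):
    torch.matmul(a, b, out=out)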
