|
1 | | -from typing import List |
| 1 | +from collections import defaultdict |
| 2 | +from typing import Callable, Dict, List |
2 | 3 | |
3 | 4 | import torch |
| 5 | +from torch.optim import Optimizer |
4 | 6 | |
5 | 7 | from pytorch_optimizer.base.exception import NoSparseGradientError |
6 | 8 | from pytorch_optimizer.base.optimizer import BaseOptimizer |
7 | | -from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS |
| 9 | +from pytorch_optimizer.base.type import BETAS, CLOSURE, DEFAULTS, LOSS, OPTIMIZER_INSTANCE_OR_CLASS, PARAMETERS, STATE |
8 | 10 | |
9 | 11 | |
10 | 12 | class ScheduleFreeSGD(BaseOptimizer): |
@@ -454,3 +456,197 @@ def step(self, closure: CLOSURE = None) -> LOSS: |
454 | 456 | z.sub_(grad, alpha=lr) |
455 | 457 | |
456 | 458 | return loss |
| 459 | + |
| 460 | + |
| 461 | +class ScheduleFreeWrapper(BaseOptimizer): |
| 462 | + r"""Wrap any optimizer to make it Schedule-Free. |
| 463 | + |
| 464 | + This version uses a memory-efficient swap operation but may be slower than the reference version. In most cases |
| 465 | + the performance difference is negligible. For the best possible performance and memory usage, Schedule-Free |
| 466 | + needs to be directly integrated with the base optimizer. |
| 467 | + |
| 468 | + When using this version, you can disable the base optimizer's momentum, as it's no longer necessary when using |
| 469 | + our wrapper's momentum (although you can use both types of momentum if you want). |
| 470 | + |
| 471 | + If you set weight decay on the base optimizer, it computes weight decay at $z$. This wrapper instead offers |
| 472 | + weight decay computed at $y$, via its `weight_decay` parameter, which seems to give better results in our |
| 473 | + experiments. This approach to decay only works correctly if the base optimizer uses group['lr'] as the current |
| 474 | + learning rate. |
| 475 | + |
| 476 | + :param optimizer: OPTIMIZER_INSTANCE_OR_CLASS. base optimizer. |
| 477 | + :param momentum: float. momentum. |
| 478 | + :param weight_decay: float. weight decay (L2 penalty). |
| 479 | + :param r: float. use polynomial weighting in the average with power r. |
| 480 | + :param weight_lr_power: float. during warmup, the weights in the average will be equal to lr raised to this power. |
| 481 | + set to 0 for no weighting. |
| 482 | + """ |
| 483 | + |
| 484 | + def __init__( |
| 485 | + self, |
| 486 | + optimizer: OPTIMIZER_INSTANCE_OR_CLASS, |
| 487 | + momentum: float = 0.9, |
| 488 | + weight_decay: float = 0.0, |
| 489 | + r: float = 0.0, |
| 490 | + weight_lr_power: float = 2.0, |
| 491 | + **kwargs, |
| 492 | + ): |
| 493 | + self.validate_range(momentum, 'momentum', 0.0, 1.0, '[)') |
| 494 | + self.validate_non_negative(weight_decay, 'weight_decay') |
| 495 | + |
| 496 | + self.momentum = momentum |
| 497 | + self.weight_decay = weight_decay |
| 498 | + self.r = r |
| 499 | + self.weight_lr_power = weight_lr_power |
| 500 | + self.train_mode: bool = False |
| 501 | + |
| 502 | + self.optimizer: Optimizer = self.load_optimizer(optimizer, **kwargs) |
| 503 | + |
| 504 | + self._optimizer_step_pre_hooks: Dict[int, Callable] = {} |
| 505 | + self._optimizer_step_post_hooks: Dict[int, Callable] = {} |
| 506 | + |
| 507 | + self.state: STATE = defaultdict(dict) |
| 508 | + |
| 509 | + for group in self.param_groups: |
| 510 | + for p in group['params']: |
| 511 | + state = self.state[p] |
| 512 | + state['z'] = torch.clone(p) |
| 513 | + |
| 514 | + self.defaults = self.optimizer.defaults |
| 515 | + |
| 516 | + def __str__(self) -> str: |
| 517 | + return 'ScheduleFree' |
| 518 | + |
| 519 | + @property |
| 520 | + def param_groups(self): |
| 521 | + return self.optimizer.param_groups |
| 522 | + |
| 523 | + def __getstate__(self): |
| 524 | + return {'state': self.state, 'optimizer': self.optimizer} |
| 525 | + |
| 526 | + def add_param_group(self, param_group): |
| 527 | + return self.optimizer.add_param_group(param_group) |
| 528 | + |
| 529 | + def state_dict(self) -> STATE: |
| 530 | + return {'schedulefree_state': self.state, 'base_optimizer': self.optimizer.state_dict()} |
| 531 | + |
| 532 | + def load_state_dict(self, state: STATE) -> None: |
| 533 | + r"""Load state.""" |
| 534 | + self.state = state['schedulefree_state'] |
| 535 | + self.optimizer.load_state_dict(state['base_optimizer']) |
| 536 | + |
| 537 | + def zero_grad(self, set_to_none: bool = True) -> None: |
| 538 | + self.optimizer.zero_grad(set_to_none) |
| 539 | + |
| 540 | + @torch.no_grad() |
| 541 | + def eval(self): |
| 542 | + if not self.train_mode: |
| 543 | + return |
| 544 | + |
| 545 | + for group in self.param_groups: |
| 546 | + for p in group['params']: |
| 547 | + state = self.state[p] |
| 548 | + if 'z' in state: |
| 549 | + p.lerp_(end=state['z'], weight=1.0 - 1.0 / self.momentum) |
| 550 | + |
| 551 | + self.train_mode = False |
| 552 | + |
| 553 | + @torch.no_grad() |
| 554 | + def train(self): |
| 555 | + if self.train_mode: |
| 556 | + return |
| 557 | + |
| 558 | + for group in self.param_groups: |
| 559 | + for p in group['params']: |
| 560 | + state = self.state[p] |
| 561 | + if 'z' in state: |
| 562 | + p.lerp_(end=state['z'], weight=1.0 - self.momentum) |
| 563 | + |
| 564 | + self.train_mode = True |
| 565 | + |
| 566 | + @torch.no_grad() |
| 567 | + def reset(self): |
| 568 | + pass |
| 569 | + |
| 570 | + @staticmethod |
| 571 | + def swap(x: torch.Tensor, y: torch.Tensor) -> None: |
| 572 | + x.view(torch.uint8).bitwise_xor_(y.view(torch.uint8)) |
| 573 | + y.view(torch.uint8).bitwise_xor_(x.view(torch.uint8)) |
| 574 | + x.view(torch.uint8).bitwise_xor_(y.view(torch.uint8)) |
| 575 | + |
| 576 | + @torch.no_grad() |
| 577 | + def step(self, closure: CLOSURE = None) -> LOSS: |
| 578 | + if not self.train_mode: |
| 579 | + raise ValueError('optimizer was not in train mode when step is called. call .train() before training') |
| 580 | + |
| 581 | + loss: LOSS = None |
| 582 | + if closure is not None: |
| 583 | + with torch.enable_grad(): |
| 584 | + loss = closure() |
| 585 | + |
| 586 | + for group in self.param_groups: |
| 587 | + for p in group['params']: |
| 588 | + if p.grad is None: |
| 589 | + continue |
| 590 | + |
| 591 | + grad = p.grad |
| 592 | + if grad.is_sparse: |
| 593 | + raise NoSparseGradientError(str(self)) |
| 594 | + |
| 595 | + state = self.state[p] |
| 596 | + |
| 597 | + z = state['z'] |
| 598 | + |
| 599 | + self.apply_weight_decay( |
| 600 | + z, |
| 601 | + grad, |
| 602 | + lr=group['lr'], |
| 603 | + weight_decay=self.weight_decay, |
| 604 | + weight_decouple=True, |
| 605 | + fixed_decay=False, |
| 606 | + ) |
| 607 | + |
| 608 | + self.apply_weight_decay( |
| 609 | + p, |
| 610 | + grad, |
| 611 | + lr=group['lr'], |
| 612 | + weight_decay=self.weight_decay, |
| 613 | + weight_decouple=True, |
| 614 | + fixed_decay=False, |
| 615 | + ratio=1.0 - self.momentum, |
| 616 | + ) |
| 617 | + |
| 618 | + p.lerp_(end=z, weight=1.0 - 1.0 / self.momentum) |
| 619 | + |
| 620 | + self.swap(z, p) |
| 621 | + |
| 622 | + self.optimizer.step() |
| 623 | + |
| 624 | + for group in self.param_groups: |
| 625 | + if 'step' in group: |
| 626 | + group['step'] += 1 |
| 627 | + else: |
| 628 | + group['step'] = 1 |
| 629 | + |
| 630 | + lr: float = group['lr'] * group.get('d', 1.0) |
| 631 | + lr_max = group['lr_max'] = max(lr, group.get('lr_max', 0)) |
| 632 | + |
| 633 | + weight: float = (group['step'] ** self.r) * (lr_max ** self.weight_lr_power) # fmt: skip |
| 634 | + weight_sum = group['weight_sum'] = group.get('weight_sum', 0.0) + weight |
| 635 | + |
| 636 | + checkpoint: float = weight / weight_sum if weight_sum != 0.0 else 0.0 |
| 637 | + |
| 638 | + for p in group['params']: |
| 639 | + if p.grad is None: |
| 640 | + continue |
| 641 | + |
| 642 | + state = self.state[p] |
| 643 | + |
| 644 | + z = state['z'] |
| 645 | + |
| 646 | + self.swap(z, p) |
| 647 | + |
| 648 | + p.lerp_(end=z, weight=checkpoint) |
| 649 | + |
| 650 | + p.lerp_(end=state['z'], weight=1.0 - self.momentum) |
| 651 | + |
| 652 | + return loss |
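For readers of this diff, here is a minimal usage sketch of the `ScheduleFreeWrapper` added above. It is not part of the commit: the top-level import path and the toy model/loss are assumptions. The point it illustrates is the mode switching, i.e. call `.train()` before taking optimization steps and `.eval()` before evaluation or checkpointing, so the parameters are swapped between the training point and the averaged point.

```python
import torch
from torch import nn

# Import path is an assumption; adjust to wherever ScheduleFreeWrapper is exported.
from pytorch_optimizer import ScheduleFreeWrapper

model = nn.Linear(10, 1)

# The base optimizer's momentum can stay at 0, since the wrapper supplies its own momentum.
base_optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
optimizer = ScheduleFreeWrapper(base_optimizer, momentum=0.9, weight_decay=1e-2)

optimizer.train()  # switch parameters to the training point before stepping
for _ in range(100):
    optimizer.zero_grad()
    loss = model(torch.randn(8, 10)).pow(2).mean()
    loss.backward()
    optimizer.step()

optimizer.eval()  # switch parameters to the averaged point before evaluating or saving
```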
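For intuition about the tail of `step()`, the interpolation factor used to pull the parameters toward `z` can be written as the standalone calculation below. This is a sketch mirroring the hunk above; `averaging_weight` is a hypothetical helper, not part of the commit. With `r=0` and a constant learning rate the factor reduces to `1/step`, i.e. a uniform average of the iterates.

```python
def averaging_weight(
    step: int, lr: float, prev_lr_max: float, prev_weight_sum: float,
    r: float = 0.0, weight_lr_power: float = 2.0,
) -> tuple:
    """Hypothetical helper returning (c_t, lr_max, weight_sum).

    c_t is the lerp weight used in step() above to move the averaged
    parameters toward z ('checkpoint' in the diff).
    """
    lr_max = max(lr, prev_lr_max)                       # running maximum learning rate
    weight = (step ** r) * (lr_max ** weight_lr_power)  # polynomial-in-step, lr-powered weight
    weight_sum = prev_weight_sum + weight
    c_t = weight / weight_sum if weight_sum != 0.0 else 0.0
    return c_t, lr_max, weight_sum
```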