Commit 167ad09

filipviz authored and pytorchmergebot committed
[optim] override SWALR.state_dict and load_state_dict (pytorch#163122)
Fixes pytorch#163105

Note that the new `SWALR.load_state_dict` is **not backwards compatible**:

```python
@override
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
    """Load the scheduler's state.

    Args:
        state_dict (dict): scheduler state. Should be an object returned
            from a call to :meth:`state_dict`.
    """
    self.__dict__.update(state_dict)
    self._set_anneal_func(self._anneal_strategy)
```

If we'd like to maintain compatibility with old state_dicts (loaded with `weights_only=False`), we could use something along these lines:

```python
@override
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
    """Load the scheduler's state.

    Args:
        state_dict (dict): scheduler state. Should be an object returned
            from a call to :meth:`state_dict`.
    """
    anneal_func = state_dict.pop("anneal_func", None)
    strategy = state_dict.get("_anneal_strategy")
    self.__dict__.update(state_dict)
    if anneal_func is not None:
        state_dict["anneal_func"] = anneal_func
        if strategy is None:
            if anneal_func == self._linear_anneal:
                strategy = "linear"
            elif anneal_func == self._cosine_anneal:
                strategy = "cos"
    if strategy is None:
        strategy = getattr(self, "_anneal_strategy", "cos")
    self._set_anneal_func(strategy)
```

But given that loading an `SWALR` state_dict before this PR would have raised an error, this seems okay. A GitHub/Google search for `SWALR.load_state_dict` had no results. Happy to change if not, or add a warning just in case.

Pull Request resolved: pytorch#163122
Approved by: https://github.com/janeyx99
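The core idea behind the PR — keep the non-picklable bound method out of the serialized state and rebuild it from a plain strategy string on load — can be sketched without torch as follows. `MiniSWALR` and its methods are illustrative stand-ins mirroring the PR's structure, not the real `SWALR`:

```python
# Illustrative, torch-free sketch of the PR's pattern: the bound method
# `anneal_func` cannot be safely serialized (e.g. with weights_only=True),
# so state_dict() excludes it and load_state_dict() reconstructs it from
# the stored `_anneal_strategy` string. MiniSWALR is a hypothetical mock.
import math
from typing import Any


class MiniSWALR:
    def __init__(self, anneal_strategy: str = "cos") -> None:
        self._set_anneal_func(anneal_strategy)

    def _cosine_anneal(self, t: float) -> float:
        return (1 - math.cos(math.pi * t)) / 2

    def _linear_anneal(self, t: float) -> float:
        return t

    def _set_anneal_func(self, anneal_strategy: str) -> None:
        # Record the strategy name alongside the callable so the
        # callable can be rebuilt later from the name alone.
        self._anneal_strategy = anneal_strategy
        if anneal_strategy == "cos":
            self.anneal_func = self._cosine_anneal
        else:
            self.anneal_func = self._linear_anneal

    def state_dict(self) -> dict[str, Any]:
        # Drop the bound method; keep only plain, picklable entries.
        return {k: v for k, v in self.__dict__.items() if k != "anneal_func"}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        self.__dict__.update(state_dict)
        self._set_anneal_func(self._anneal_strategy)


saved = MiniSWALR("linear").state_dict()  # contains no callables
restored = MiniSWALR()                    # defaults to "cos"
restored.load_state_dict(saved)
print(restored.anneal_func(0.5))          # linear anneal -> 0.5
```

The round trip restores `anneal_func` even though it never appears in the saved dict, which is exactly what makes the saved state compatible with `torch.load(..., weights_only=True)` in the real scheduler.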
1 parent bcbb45b commit 167ad09

File tree

2 files changed: +34 -4 lines changed


test/optim/test_lrscheduler.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -2442,6 +2442,7 @@ def test_cosine_then_cyclic(self):
         partial(CyclicLR, base_lr=0.01, max_lr=0.1),
         partial(OneCycleLR, max_lr=0.01, total_steps=10, anneal_strategy="linear"),
         partial(CosineAnnealingWarmRestarts, T_0=20),
+        partial(SWALR, swa_lr=0.01),
     ],
 )
 @parametrize("weights_only", [True, False])
```

torch/optim/swa_utils.py

Lines changed: 33 additions & 4 deletions

```diff
@@ -7,6 +7,7 @@
 from collections.abc import Iterable
 from copy import deepcopy
 from typing import Any, Callable, cast, Literal, Optional, Union
+from typing_extensions import override

 import torch
 from torch import Tensor
@@ -431,10 +432,7 @@ def __init__(
                 "anneal_strategy must by one of 'cos' or 'linear', "
                 f"instead got {anneal_strategy}"
             )
-        elif anneal_strategy == "cos":
-            self.anneal_func = self._cosine_anneal
-        elif anneal_strategy == "linear":
-            self.anneal_func = self._linear_anneal
+        self._set_anneal_func(anneal_strategy)
         if not isinstance(anneal_epochs, int) or anneal_epochs < 0:
             raise ValueError(
                 f"anneal_epochs must be equal or greater than 0, got {anneal_epochs}"
@@ -482,3 +480,34 @@ def get_lr(self):
             group["swa_lr"] * alpha + lr * (1 - alpha)
             for group, lr in zip(self.optimizer.param_groups, prev_lrs)
         ]
+
+    def _set_anneal_func(self, anneal_strategy: Literal["cos", "linear"]):
+        self._anneal_strategy = anneal_strategy
+        if anneal_strategy == "cos":
+            self.anneal_func = self._cosine_anneal
+        else:
+            self.anneal_func = self._linear_anneal
+
+    @override
+    def state_dict(self) -> dict[str, Any]:
+        """Return the state of the scheduler as a :class:`dict`.
+
+        It contains an entry for every variable in self.__dict__ which
+        is not the optimizer or anneal_func.
+        """
+        return {
+            key: value
+            for key, value in self.__dict__.items()
+            if key not in ("optimizer", "anneal_func")
+        }
+
+    @override
+    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+        """Load the scheduler's state.
+
+        Args:
+            state_dict (dict): scheduler state. Should be an object returned
+                from a call to :meth:`state_dict`.
+        """
+        self.__dict__.update(state_dict)
+        self._set_anneal_func(self._anneal_strategy)
```
