|
7 | 7 | from collections.abc import Iterable
|
8 | 8 | from copy import deepcopy
|
9 | 9 | from typing import Any, Callable, cast, Literal, Optional, Union
|
| 10 | +from typing_extensions import override |
10 | 11 |
|
11 | 12 | import torch
|
12 | 13 | from torch import Tensor
|
@@ -431,10 +432,7 @@ def __init__(
|
431 | 432 | "anneal_strategy must by one of 'cos' or 'linear', "
|
432 | 433 | f"instead got {anneal_strategy}"
|
433 | 434 | )
|
434 |
| - elif anneal_strategy == "cos": |
435 |
| - self.anneal_func = self._cosine_anneal |
436 |
| - elif anneal_strategy == "linear": |
437 |
| - self.anneal_func = self._linear_anneal |
| 435 | + self._set_anneal_func(anneal_strategy) |
438 | 436 | if not isinstance(anneal_epochs, int) or anneal_epochs < 0:
|
439 | 437 | raise ValueError(
|
440 | 438 | f"anneal_epochs must be equal or greater than 0, got {anneal_epochs}"
|
@@ -482,3 +480,34 @@ def get_lr(self):
|
482 | 480 | group["swa_lr"] * alpha + lr * (1 - alpha)
|
483 | 481 | for group, lr in zip(self.optimizer.param_groups, prev_lrs)
|
484 | 482 | ]
|
| 483 | + |
| 484 | + def _set_anneal_func(self, anneal_strategy: Literal["cos", "linear"]): |
| 485 | + self._anneal_strategy = anneal_strategy |
| 486 | + if anneal_strategy == "cos": |
| 487 | + self.anneal_func = self._cosine_anneal |
| 488 | + else: |
| 489 | + self.anneal_func = self._linear_anneal |
| 490 | + |
| 491 | + @override |
| 492 | + def state_dict(self) -> dict[str, Any]: |
| 493 | + """Return the state of the scheduler as a :class:`dict`. |
| 494 | +
|
| 495 | + It contains an entry for every variable in self.__dict__ which |
| 496 | + is not the optimizer or anneal_func. |
| 497 | + """ |
| 498 | + return { |
| 499 | + key: value |
| 500 | + for key, value in self.__dict__.items() |
| 501 | + if key not in ("optimizer", "anneal_func") |
| 502 | + } |
| 503 | + |
| 504 | + @override |
| 505 | + def load_state_dict(self, state_dict: dict[str, Any]) -> None: |
| 506 | + """Load the scheduler's state. |
| 507 | +
|
| 508 | + Args: |
| 509 | + state_dict (dict): scheduler state. Should be an object returned |
| 510 | + from a call to :meth:`state_dict`. |
| 511 | + """ |
| 512 | + self.__dict__.update(state_dict) |
| 513 | + self._set_anneal_func(self._anneal_strategy) |
0 commit comments