Commit d86199f

Merge pull request #101 from KamitaniLab/fix-scheduler-behavior
Fix problematic behavior of optimizer/scheduler in FeatureInversionTask
2 parents: 50cac33 + c23afe7
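In short: `FeatureInversionTask` no longer takes a concrete `optimizer`/`scheduler` pair; it takes factory callables and rebuilds both objects in `reset_states()`. A minimal sketch of the new call pattern, assuming `encoder`, `generator`, `latent`, `critic`, and `target_image` are set up as in the docstring example in the diff below:

import torch.optim as optim

from bdpy.recon.torch.modules import build_optimizer_factory, build_scheduler_factory
from bdpy.recon.torch.task.inversion import FeatureInversionTask

# Factories capture the optimizer/scheduler class and its keyword arguments;
# the task then builds a fresh optimizer and a scheduler bound to it on every
# reset, instead of reviving the old optimizer from its `defaults` dict.
optimizer_factory = build_optimizer_factory(optim.Adam, lr=1e-3)
scheduler_factory = build_scheduler_factory(
    optim.lr_scheduler.StepLR, step_size=100, gamma=0.1
)
task = FeatureInversionTask(
    encoder, generator, latent, critic,
    optimizer_factory, scheduler_factory, num_iterations=200,
)
reconstructed_image = task(encoder(target_image))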

File tree

5 files changed (+250, -40 lines)


bdpy/recon/torch/modules/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -2,3 +2,4 @@
 from .generator import build_generator, BaseGenerator
 from .latent import ArbitraryLatent, BaseLatent
 from .critic import TargetNormalizedMSE, BaseCritic
+from .optimizer import build_optimizer_factory, build_scheduler_factory
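For reference, the resulting package-level import surface (every name here appears in this `__init__.py`):

from bdpy.recon.torch.modules import (
    build_generator, BaseGenerator,
    ArbitraryLatent, BaseLatent,
    TargetNormalizedMSE, BaseCritic,
    build_optimizer_factory, build_scheduler_factory,
)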
bdpy/recon/torch/modules/optimizer.py

Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from functools import partial
from itertools import chain

if TYPE_CHECKING:
    from typing import Dict, Any, Tuple, Union, Iterable, Callable
    from typing_extensions import TypeAlias
    from torch import Tensor
    import torch.optim as optim
    from ..modules import BaseGenerator, BaseLatent

# NOTE: The definition of `_ParamsT` is the same as in `torch.optim.optimizer`
# in torch>=2.2.0. We define it here for compatibility with older versions.
_ParamsT: TypeAlias = Union[
    Iterable[Tensor], Iterable[Dict[str, Any]], Iterable[Tuple[str, Tensor]]
]

_OptimizerFactoryType: TypeAlias = Callable[
    [BaseGenerator, BaseLatent], optim.Optimizer
]
_SchedulerFactoryType: TypeAlias = Callable[
    [optim.Optimizer], optim.lr_scheduler.LRScheduler
]
_GetParamsFnType: TypeAlias = Callable[[BaseGenerator, BaseLatent], _ParamsT]


def build_optimizer_factory(
    optimizer_class: type[optim.Optimizer],
    *,
    get_params_fn: _GetParamsFnType | None = None,
    **kwargs,
) -> _OptimizerFactoryType:
    """Build an optimizer factory.

    Parameters
    ----------
    optimizer_class : type
        Optimizer class.
    get_params_fn : Callable[[BaseGenerator, BaseLatent], _ParamsT] | None
        Custom function to get parameters from the generator and the latent.
        If None, it uses `chain(generator.parameters(), latent.parameters())`.
    kwargs : dict
        Keyword arguments for the optimizer.

    Returns
    -------
    Callable[[BaseGenerator, BaseLatent], optim.Optimizer]
        Optimizer factory.

    Examples
    --------
    >>> from torch.optim import Adam
    >>> from bdpy.recon.torch.modules import build_optimizer_factory
    >>> optimizer_factory = build_optimizer_factory(Adam, lr=1e-3)
    >>> optimizer = optimizer_factory(generator, latent)
    """
    if get_params_fn is None:
        get_params_fn = lambda generator, latent: chain(
            generator.parameters(), latent.parameters()
        )

    def init_fn(generator: BaseGenerator, latent: BaseLatent) -> optim.Optimizer:
        return optimizer_class(get_params_fn(generator, latent), **kwargs)

    return init_fn


def build_scheduler_factory(
    scheduler_class: type[optim.lr_scheduler.LRScheduler], **kwargs
) -> _SchedulerFactoryType:
    """Build a scheduler factory.

    Parameters
    ----------
    scheduler_class : type
        Scheduler class.
    kwargs : dict
        Keyword arguments for the scheduler.

    Returns
    -------
    Callable[[optim.Optimizer], optim.lr_scheduler.LRScheduler]
        Scheduler factory.

    Examples
    --------
    >>> from torch.optim.lr_scheduler import StepLR
    >>> from bdpy.recon.torch.modules import build_scheduler_factory
    >>> scheduler_factory = build_scheduler_factory(StepLR, step_size=100, gamma=0.1)
    >>> scheduler = scheduler_factory(optimizer)
    """
    return partial(scheduler_class, **kwargs)
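One thing the new module enables, sketched here under the assumption that `generator` and `latent` already exist (`latent_only_factory` is a hypothetical name): the keyword-only `get_params_fn` hook overrides the default `chain(generator.parameters(), latent.parameters())`, e.g. to optimize the latent alone:

import torch.optim as optim

from bdpy.recon.torch.modules import build_optimizer_factory

# Pass only the latent's parameters to the optimizer; the (typically
# frozen) generator's parameters are left out entirely.
latent_only_factory = build_optimizer_factory(
    optim.Adam,
    get_params_fn=lambda generator, latent: latent.parameters(),
    lr=1e-2,
)
optimizer = latent_only_factory(generator, latent)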

bdpy/recon/torch/task/inversion.py

Lines changed: 26 additions & 25 deletions
@@ -1,21 +1,24 @@
 from __future__ import annotations

-from typing import Dict, Iterable, Callable
+from typing import Dict, Iterable, Callable, TYPE_CHECKING

 from itertools import chain

-import torch
-
-from ..modules import BaseEncoder, BaseGenerator, BaseLatent, BaseCritic
 from bdpy.task import BaseTask
 from bdpy.task.callback import BaseCallback, unused, _validate_callback

-FeatureType = Dict[str, torch.Tensor]
+if TYPE_CHECKING:
+    import torch
+
+    from ..modules import BaseEncoder, BaseGenerator, BaseLatent, BaseCritic
+    from ..modules.optimizer import _OptimizerFactoryType, _SchedulerFactoryType
+
+_FeatureType = Dict[str, torch.Tensor]


 def _apply_to_features(
-    fn: Callable[[torch.Tensor], torch.Tensor], features: FeatureType
-) -> FeatureType:
+    fn: Callable[[torch.Tensor], torch.Tensor], features: _FeatureType
+) -> _FeatureType:
     return {k: fn(v) for k, v in features.items()}
@@ -115,10 +118,10 @@ class FeatureInversionTask(BaseTask):
         Latent variable module.
     critic : BaseCritic
         Critic module.
-    optimizer : torch.optim.Optimizer
-        Optimizer.
-    scheduler : torch.optim.lr_scheduler.LRScheduler, optional
-        Learning rate scheduler, by default None.
+    optimizer_factory : _OptimizerFactoryType
+        Factory function for optimizer.
+    scheduler_factory : _SchedulerFactoryType | None, optional
+        Factory function for scheduler, by default None.
     num_iterations : int, optional
         Number of iterations, by default 1.
     callbacks : FeatureInversionCallback | Iterable[FeatureInversionCallback] | None, optional
@@ -135,9 +138,9 @@ class FeatureInversionTask(BaseTask):
     >>> generator = build_generator(...)
     >>> latent = ArbitraryLatent(...)
     >>> critic = TargetNormalizedMSE(...)
-    >>> optimizer = torch.optim.Adam(latent.parameters())
+    >>> optimizer_factory = build_optimizer_factory(...)
     >>> task = FeatureInversionTask(
-    ...     encoder, generator, latent, critic, optimizer, num_iterations=200,
+    ...     encoder, generator, latent, critic, optimizer_factory, num_iterations=200,
     ... )
     >>> target_features = encoder(target_image)
     >>> reconstructed_image = task(target_features)
@@ -149,8 +152,8 @@ def __init__(
         generator: BaseGenerator,
         latent: BaseLatent,
         critic: BaseCritic,
-        optimizer: torch.optim.Optimizer,
-        scheduler: torch.optim.lr_scheduler.LRScheduler = None,
+        optimizer_factory: _OptimizerFactoryType,
+        scheduler_factory: _SchedulerFactoryType | None = None,
         num_iterations: int = 1,
         callbacks: FeatureInversionCallback
         | Iterable[FeatureInversionCallback]
@@ -161,14 +164,14 @@ def __init__(
         self._generator = generator
         self._latent = latent
         self._critic = critic
-        self._optimizer = optimizer
-        self._scheduler = scheduler
+        self._optimizer_factory = optimizer_factory
+        self._scheduler_factory = scheduler_factory

         self._num_iterations = num_iterations

     def run(
         self,
-        target_features: FeatureType,
+        target_features: _FeatureType,
     ) -> torch.Tensor:
         """Run feature inversion given target features.

@@ -217,10 +220,8 @@ def reset_states(self) -> None:
         """Reset the state of the task."""
         self._generator.reset_states()
         self._latent.reset_states()
-        self._optimizer = self._optimizer.__class__(
-            chain(
-                self._generator.parameters(),
-                self._latent.parameters(),
-            ),
-            **self._optimizer.defaults,
-        )
+        self._optimizer = self._optimizer_factory(self._generator, self._latent)
+        if self._scheduler_factory is not None:
+            self._scheduler = self._scheduler_factory(self._optimizer)
+        else:
+            self._scheduler = None
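Why this matters, as a self-contained sketch in plain PyTorch (no bdpy required): the old `reset_states()` rebuilt the optimizer from its class and `defaults`, which discards per-parameter-group options and leaves a previously created scheduler stepping the stale optimizer.

import torch
import torch.optim as optim

params = [torch.nn.Parameter(torch.zeros(1))]
opt = optim.SGD([{"params": params, "lr": 0.01}], lr=0.1)
sched = optim.lr_scheduler.StepLR(opt, step_size=1, gamma=0.1)

# The old reset pattern: re-instantiate from the class and its defaults.
new_opt = opt.__class__(params, **opt.defaults)

print(new_opt.param_groups[0]["lr"])  # 0.1 -- the per-group lr=0.01 is lost
print(sched.optimizer is new_opt)     # False -- scheduler still points at the old optimizer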
Tests for bdpy.recon.torch.modules.optimizer

Lines changed: 113 additions & 0 deletions
@@ -0,0 +1,113 @@
"""Tests for bdpy.recon.torch.modules.optimizer"""

from __future__ import annotations

import unittest

from functools import partial
import numpy as np
import torch.nn as nn
import torch.optim as optim
from bdpy.recon.torch.modules import build_generator, ArbitraryLatent
from bdpy.recon.torch.modules import build_optimizer_factory, build_scheduler_factory


class MLP(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.fc = nn.Linear(in_dim, out_dim)

    def forward(self, x):
        return self.fc(x)


class TestBuildOptimizerFactory(unittest.TestCase):
    """Tests for bdpy.recon.torch.modules.optimizer.build_optimizer_factory"""

    def test_build_optimizer_factory(self):
        generator = build_generator(MLP(64, 10))
        latent = ArbitraryLatent(
            (1, 64), init_fn=partial(nn.init.normal_, mean=0, std=1)
        )
        optimizer_factory = build_optimizer_factory(optim.SGD, lr=0.1)
        optimizer = optimizer_factory(generator, latent)
        self.assertIsInstance(
            optimizer,
            optim.SGD,
            msg="optimizer_factory should return an instance of optim.Optimizer",
        )

        latent.reset_states()
        generator.reset_states()
        latent_prev = latent().detach().clone().numpy()
        optimizer.zero_grad()
        output = generator(latent())
        loss = output.sum()
        loss.backward()
        latent_next_expected = (
            latent_prev - 0.1 * latent().grad.detach().clone().numpy()
        )
        optimizer.step()
        latent_next = latent().detach().clone().numpy()
        np.testing.assert_allclose(
            latent_next,
            latent_next_expected,
            rtol=1e-6,
            err_msg="Optimizer does not update the latent variable correctly.",
        )

        # check if all the frozen generator's gradients are None
        generator_grad = [p.grad for p in generator.parameters()]
        self.assertTrue(
            all([g is None for g in generator_grad]),
            msg="Frozen generator's gradients should be None after the optimizer step.",
        )


class TestBuildSchedulerFactory(unittest.TestCase):
    """Tests for bdpy.recon.torch.modules.optimizer.build_scheduler_factory"""

    def test_build_scheduler_factory(self):
        generator = build_generator(MLP(64, 10))
        latent = ArbitraryLatent(
            (1, 64), init_fn=partial(nn.init.normal_, mean=0, std=1)
        )
        optimizer_factory = build_optimizer_factory(optim.SGD, lr=0.1)
        scheduler_factory = build_scheduler_factory(
            optim.lr_scheduler.StepLR, step_size=1, gamma=0.1
        )
        optimizer = optimizer_factory(generator, latent)
        scheduler = scheduler_factory(optimizer)
        self.assertIsInstance(
            scheduler,
            optim.lr_scheduler.StepLR,
            msg="Scheduler factory should return an instance of optim.lr_scheduler.LRScheduler",
        )

        latent.reset_states()
        generator.reset_states()
        optimizer.zero_grad()
        output = generator(latent())
        loss = output.sum()
        loss.backward()
        optimizer.step()
        scheduler.step()
        self.assertEqual(
            optimizer.param_groups[0]["lr"],
            0.1 * 0.1,
            "Scheduler does not update the learning rate correctly.",
        )

        # check if reference to the optimizer is kept during re-initialization
        for _ in range(10):
            optimizer = optimizer_factory(generator, latent)
            scheduler = scheduler_factory(optimizer)
        else:
            self.assertTrue(
                scheduler.optimizer is optimizer,
                "Scheduler should keep the reference to the optimizer during re-initialization.",
            )


if __name__ == "__main__":
    unittest.main()
