Commit 1ed1491

fix switch_optimizer callback
1 parent 6d10989 commit 1ed1491

2 files changed: +62 -37 lines changed

pina/callback/optimizer_callback.py

Lines changed: 19 additions & 12 deletions
@@ -21,45 +21,52 @@ def __init__(self, new_optimizers, epoch_switch):
             single :class:`torch.optim.Optimizer` instance or a list of them
             for multiple model solver.
         :type new_optimizers: pina.optim.TorchOptimizer | list
-        :param epoch_switch: The epoch at which the optimizer switch occurs.
-        :type epoch_switch: int
+        :param int epoch_switch: The epoch at which the optimizer switch occurs.
 
         Example:
-            >>> switch_callback = SwitchOptimizer(new_optimizers=optimizer,
-            >>>                                   epoch_switch=10)
+            >>> optimizer = TorchOptimizer(torch.optim.Adam, lr=0.01)
+            >>> switch_callback = SwitchOptimizer(
+            >>>     new_optimizers=optimizer, epoch_switch=10
+            >>> )
         """
         super().__init__()
 
+        # Check if epoch_switch is greater than 1
         if epoch_switch < 1:
             raise ValueError("epoch_switch must be greater than one.")
 
+        # If new_optimizers is not a list, convert it to a list
         if not isinstance(new_optimizers, list):
             new_optimizers = [new_optimizers]
 
-        # check type consistency
+        # Check consistency
+        check_consistency(epoch_switch, int)
         for optimizer in new_optimizers:
             check_consistency(optimizer, TorchOptimizer)
-        check_consistency(epoch_switch, int)
-        # save new optimizers
+
+        # Store the new optimizers and epoch switch
         self._new_optimizers = new_optimizers
         self._epoch_switch = epoch_switch
 
     def on_train_epoch_start(self, trainer, __):
         """
         Switch the optimizer at the start of the specified training epoch.
 
-        :param trainer: The trainer object managing the training process.
-        :type trainer: pytorch_lightning.Trainer
+        :param lightning.pytorch.Trainer trainer: The trainer object managing
+            the training process.
         :param _: Placeholder argument (not used).
-
-        :return: None
-        :rtype: None
         """
+        # Check if the current epoch matches the switch epoch
         if trainer.current_epoch == self._epoch_switch:
             optims = []
 
+            # Hook the new optimizers to the model parameters
             for idx, optim in enumerate(self._new_optimizers):
                 optim.hook(trainer.solver._pina_models[idx].parameters())
                 optims.append(optim)
 
+            # Update the solver's optimizers
             trainer.solver._pina_optimizers = optims
+
+            # Update the trainer's strategy optimizers
+            trainer.strategy.optimizers = [o.instance for o in optims]
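Note on the fix: PINA stores its optimizer wrappers on the solver (trainer.solver._pina_optimizers), but Lightning steps the optimizers held by the training strategy. Before this commit the callback updated only the solver's list, which is presumably why the switch did not take effect; the new final line also pushes the unwrapped torch.optim instances into trainer.strategy.optimizers. A minimal standalone sketch of the same idea in plain Lightning (the class name MinimalSwitch and the optimizer_kwargs plumbing are illustrative, not PINA code):

from lightning.pytorch import Callback


class MinimalSwitch(Callback):
    """Illustrative sketch: swap the optimizer at a fixed epoch."""

    def __init__(self, optimizer_class, epoch_switch, **optimizer_kwargs):
        super().__init__()
        self.optimizer_class = optimizer_class    # e.g. torch.optim.LBFGS
        self.epoch_switch = epoch_switch
        self.optimizer_kwargs = optimizer_kwargs  # e.g. lr=1.0

    def on_train_epoch_start(self, trainer, pl_module):
        if trainer.current_epoch == self.epoch_switch:
            new_opt = self.optimizer_class(
                pl_module.parameters(), **self.optimizer_kwargs
            )
            # The crux of this commit: the strategy's optimizer list is what
            # Lightning actually steps, so it is the one that must be replaced.
            trainer.strategy.optimizers = [new_opt]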
Lines changed: 43 additions & 25 deletions
@@ -1,45 +1,63 @@
-from pina.callback import SwitchOptimizer
 import torch
 import pytest
 
 from pina.solver import PINN
 from pina.trainer import Trainer
 from pina.model import FeedForward
-from pina.problem.zoo import Poisson2DSquareProblem as Poisson
 from pina.optim import TorchOptimizer
+from pina.callback import SwitchOptimizer
+from pina.problem.zoo import Poisson2DSquareProblem as Poisson
+
+
+# Define the problem
+problem = Poisson()
+problem.discretise_domain(10)
+model = FeedForward(len(problem.input_variables), len(problem.output_variables))
+
+# Define the optimizer
+optimizer = TorchOptimizer(torch.optim.Adam)
 
-# make the problem
-poisson_problem = Poisson()
-boundaries = ["g1", "g2", "g3", "g4"]
-n = 10
-poisson_problem.discretise_domain(n, "grid", domains=boundaries)
-poisson_problem.discretise_domain(n, "grid", domains="D")
-model = FeedForward(
-    len(poisson_problem.input_variables), len(poisson_problem.output_variables)
-)
+# Initialize the solver
+solver = PINN(problem=problem, model=model, optimizer=optimizer)
 
-# make the solver
-solver = PINN(problem=poisson_problem, model=model)
+# Define new optimizers for testing
+lbfgs = TorchOptimizer(torch.optim.LBFGS, lr=1.0)
+adamW = TorchOptimizer(torch.optim.AdamW, lr=0.01)
 
-adam = TorchOptimizer(torch.optim.Adam, lr=0.01)
-lbfgs = TorchOptimizer(torch.optim.LBFGS, lr=0.001)
 
+@pytest.mark.parametrize("epoch_switch", [5, 10])
+@pytest.mark.parametrize("new_opt", [lbfgs, adamW])
+def test_switch_optimizer_constructor(new_opt, epoch_switch):
 
-def test_switch_optimizer_constructor():
-    SwitchOptimizer(adam, epoch_switch=10)
+    # Constructor
+    SwitchOptimizer(new_optimizers=new_opt, epoch_switch=epoch_switch)
 
+    # Should fail if epoch_switch is less than 1
+    with pytest.raises(ValueError):
+        SwitchOptimizer(new_optimizers=new_opt, epoch_switch=0)
 
-def test_switch_optimizer_routine():
-    # check initial optimizer
+
+@pytest.mark.parametrize("epoch_switch", [5, 10])
+@pytest.mark.parametrize("new_opt", [lbfgs, adamW])
+def test_switch_optimizer_routine(new_opt, epoch_switch):
+
+    # Check if the optimizer is initialized correctly
     solver.configure_optimizers()
-    assert solver.optimizer.instance.__class__ == torch.optim.Adam
-    # make the trainer
-    switch_opt_callback = SwitchOptimizer(lbfgs, epoch_switch=3)
+
+    # Initialize the trainer
+    switch_opt_callback = SwitchOptimizer(
+        new_optimizers=new_opt, epoch_switch=epoch_switch
+    )
     trainer = Trainer(
         solver=solver,
-        callbacks=[switch_opt_callback],
+        callbacks=switch_opt_callback,
         accelerator="cpu",
-        max_epochs=5,
+        max_epochs=epoch_switch + 2,
     )
     trainer.train()
-    assert solver.optimizer.instance.__class__ == torch.optim.LBFGS
+
+    # Check that the trainer strategy optimizers have been updated
+    assert solver.optimizer.instance.__class__ == new_opt.instance.__class__
+    assert (
+        trainer.strategy.optimizers[0].__class__ == new_opt.instance.__class__
+    )
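For reference, the tests exercise two members of pina.optim.TorchOptimizer that also appear in the callback above: hook(parameters), which binds the wrapped optimizer class to model parameters, and instance, the resulting torch.optim.Optimizer object. The wrapper's real implementation is not part of this diff; the following rough stand-in only records the contract the diff assumes (the class name TorchOptimizerSketch is hypothetical):

import torch


class TorchOptimizerSketch:
    """Hypothetical stand-in for pina.optim.TorchOptimizer.

    Only the contract used by SwitchOptimizer and the tests above is
    sketched here; this is not PINA's actual implementation.
    """

    def __init__(self, optimizer_class, **kwargs):
        self.optimizer_class = optimizer_class  # e.g. torch.optim.AdamW
        self.kwargs = kwargs                    # e.g. lr=0.01
        self.instance = None                    # populated by hook()

    def hook(self, parameters):
        # Instantiate the torch optimizer over the given parameters, as the
        # callback does via optim.hook(...parameters()).
        self.instance = self.optimizer_class(parameters, **self.kwargs)


# Example usage (mirrors the tests):
adamW = TorchOptimizerSketch(torch.optim.AdamW, lr=0.01)
adamW.hook(torch.nn.Linear(2, 1).parameters())
assert isinstance(adamW.instance, torch.optim.AdamW)

Under this contract, new_opt.instance is presumably only meaningful once the callback has hooked the optimizer during training, which would explain why the class-equality assertions run after trainer.train().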
