|
1 | | -from pina.callback import SwitchOptimizer |
2 | 1 | import torch |
3 | 2 | import pytest |
4 | 3 |
|
5 | 4 | from pina.solver import PINN |
6 | 5 | from pina.trainer import Trainer |
7 | 6 | from pina.model import FeedForward |
8 | | -from pina.problem.zoo import Poisson2DSquareProblem as Poisson |
9 | 7 | from pina.optim import TorchOptimizer |
| 8 | +from pina.callback import SwitchOptimizer |
| 9 | +from pina.problem.zoo import Poisson2DSquareProblem as Poisson |
| 10 | + |
| 11 | + |
# ---------------------------------------------------------------------------
# Module-level fixtures shared by all tests below.
# NOTE(review): `solver` is shared and mutated across parametrized runs; each
# test re-calls `solver.configure_optimizers()` to reset it — confirm this is
# sufficient isolation for the PINN solver.
# ---------------------------------------------------------------------------

# Define the problem and discretise its domain (10 points per domain).
problem = Poisson()
problem.discretise_domain(10)

# Feed-forward surrogate sized to the problem's input/output variables.
model = FeedForward(len(problem.input_variables), len(problem.output_variables))

# Define the initial optimizer (Adam, library-default hyperparameters).
optimizer = TorchOptimizer(torch.optim.Adam)

# Initialize the solver with the initial optimizer.
solver = PINN(problem=problem, model=model, optimizer=optimizer)

# Optimizers the SwitchOptimizer callback will switch to during the tests.
lbfgs = TorchOptimizer(torch.optim.LBFGS, lr=1.0)
adamW = TorchOptimizer(torch.optim.AdamW, lr=0.01)
|
@pytest.mark.parametrize("epoch_switch", [5, 10])
@pytest.mark.parametrize("new_opt", [lbfgs, adamW])
def test_switch_optimizer_constructor(new_opt, epoch_switch):
    """Constructor accepts valid switch epochs and rejects epoch_switch < 1."""

    # epoch_switch below 1 must be rejected.
    with pytest.raises(ValueError):
        SwitchOptimizer(new_optimizers=new_opt, epoch_switch=0)

    # A valid configuration constructs without error.
    SwitchOptimizer(new_optimizers=new_opt, epoch_switch=epoch_switch)
31 | 38 |
|
32 | | -def test_switch_optimizer_routine(): |
33 | | - # check initial optimizer |
| 39 | + |
| 40 | +@pytest.mark.parametrize("epoch_switch", [5, 10]) |
| 41 | +@pytest.mark.parametrize("new_opt", [lbfgs, adamW]) |
| 42 | +def test_switch_optimizer_routine(new_opt, epoch_switch): |
| 43 | + |
| 44 | + # Check if the optimizer is initialized correctly |
34 | 45 | solver.configure_optimizers() |
35 | | - assert solver.optimizer.instance.__class__ == torch.optim.Adam |
36 | | - # make the trainer |
37 | | - switch_opt_callback = SwitchOptimizer(lbfgs, epoch_switch=3) |
| 46 | + |
| 47 | + # Initialize the trainer |
| 48 | + switch_opt_callback = SwitchOptimizer( |
| 49 | + new_optimizers=new_opt, epoch_switch=epoch_switch |
| 50 | + ) |
38 | 51 | trainer = Trainer( |
39 | 52 | solver=solver, |
40 | | - callbacks=[switch_opt_callback], |
| 53 | + callbacks=switch_opt_callback, |
41 | 54 | accelerator="cpu", |
42 | | - max_epochs=5, |
| 55 | + max_epochs=epoch_switch + 2, |
43 | 56 | ) |
44 | 57 | trainer.train() |
45 | | - assert solver.optimizer.instance.__class__ == torch.optim.LBFGS |
| 58 | + |
| 59 | + # Check that the trainer strategy optimizers have been updated |
| 60 | + assert solver.optimizer.instance.__class__ == new_opt.instance.__class__ |
| 61 | + assert ( |
| 62 | + trainer.strategy.optimizers[0].__class__ == new_opt.instance.__class__ |
| 63 | + ) |
0 commit comments