Skip to content

Commit 0a60ed4

Browse files
authored
Update MultiSolverInterface (#520)
1 parent 4357f86 commit 0a60ed4

File tree

3 files changed

+36
-9
lines changed

3 files changed

+36
-9
lines changed

pina/solver/physics_informed_solver/competitive_pinn.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,9 +103,6 @@ def __init__(
103103
loss=loss,
104104
)
105105

106-
# Set automatic optimization to False
107-
self.automatic_optimization = False
108-
109106
def forward(self, x):
110107
"""
111108
Forward pass.

pina/solver/physics_informed_solver/self_adaptive_pinn.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -158,9 +158,6 @@ def __init__(
158158
loss=loss,
159159
)
160160

161-
# Set automatic optimization to False
162-
self.automatic_optimization = False
163-
164161
self._vectorial_loss = deepcopy(self.loss)
165162
self._vectorial_loss.reduction = "none"
166163

pina/solver/solver.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,13 @@
1414

1515
class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
1616
"""
17-
Abstract base class for PINA solvers. All specific solvers should inherit
18-
from this interface. This class is a wrapper of
19-
:class:`~lightning.pytorch.LightningModule`.
17+
Abstract base class for PINA solvers. All specific solvers must inherit
18+
from this interface. This class extends
19+
:class:`~lightning.pytorch.core.LightningModule`, providing additional
20+
functionalities for defining and optimizing Deep Learning models.
21+
22+
By inheriting from this base class, solvers gain access to built-in training
23+
loops, logging utilities, and optimization techniques.
2024
"""
2125

2226
def __init__(self, problem, weighting, use_lt):
@@ -442,6 +446,14 @@ def __init__(
442446
:param bool use_lt: If ``True``, the solver uses LabelTensors as input.
443447
:raises ValueError: If the models are not a list or tuple with length
444448
greater than one.
449+
450+
.. warning::
451+
:class:`MultiSolverInterface` uses manual optimization by setting
452+
``automatic_optimization=False`` in
453+
:class:`~lightning.pytorch.core.LightningModule`. For more
454+
information on manual optimization please
455+
see `here <https://lightning.ai/docs/pytorch/stable/\
456+
model/manual_optimization.html>`_.
445457
"""
446458
if not isinstance(models, (list, tuple)) or len(models) < 2:
447459
raise ValueError(
@@ -450,6 +462,16 @@ def __init__(
450462
"one."
451463
)
452464

465+
if optimizers is None:
466+
optimizers = [
467+
self.default_torch_optimizer() for _ in range(len(models))
468+
]
469+
470+
if schedulers is None:
471+
schedulers = [
472+
self.default_torch_scheduler() for _ in range(len(models))
473+
]
474+
453475
if any(opt is None for opt in optimizers):
454476
optimizers = [
455477
self.default_torch_optimizer() if opt is None else opt
@@ -480,12 +502,23 @@ def __init__(
480502
f"Got {len(models)} models, and {len(optimizers)}"
481503
" optimizers."
482504
)
505+
if len(schedulers) != len(optimizers):
506+
raise ValueError(
507+
"You must define one scheduler for each optimizer."
508+
f"Got {len(schedulers)} schedulers, and {len(optimizers)}"
509+
" optimizers."
510+
)
483511

484512
# initialize the model
485513
self._pina_models = torch.nn.ModuleList(models)
486514
self._pina_optimizers = optimizers
487515
self._pina_schedulers = schedulers
488516

517+
# Set automatic optimization to False.
518+
# For more information on manual optimization see:
519+
# http://lightning.ai/docs/pytorch/stable/model/manual_optimization.html
520+
self.automatic_optimization = False
521+
489522
def configure_optimizers(self):
490523
"""
491524
Optimizer configuration for the solver.

0 commit comments

Comments
 (0)