Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions pina/solver/physics_informed_solver/competitive_pinn.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,6 @@ def __init__(
loss=loss,
)

# Set automatic optimization to False
self.automatic_optimization = False

def forward(self, x):
"""
Forward pass.
Expand Down
3 changes: 0 additions & 3 deletions pina/solver/physics_informed_solver/self_adaptive_pinn.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,6 @@ def __init__(
loss=loss,
)

# Set automatic optimization to False
self.automatic_optimization = False

self._vectorial_loss = deepcopy(self.loss)
self._vectorial_loss.reduction = "none"

Expand Down
39 changes: 36 additions & 3 deletions pina/solver/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,13 @@

class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
"""
Abstract base class for PINA solvers. All specific solvers should inherit
from this interface. This class is a wrapper of
:class:`~lightning.pytorch.LightningModule`.
Abstract base class for PINA solvers. All specific solvers must inherit
from this interface. This class extends
:class:`~lightning.pytorch.core.LightningModule`, providing additional
functionalities for defining and optimizing Deep Learning models.

By inheriting from this base class, solvers gain access to built-in training
loops, logging utilities, and optimization techniques.
"""

def __init__(self, problem, weighting, use_lt):
Expand Down Expand Up @@ -442,6 +446,14 @@ def __init__(
:param bool use_lt: If ``True``, the solver uses LabelTensors as input.
:raises ValueError: If the models are not a list or tuple with length
greater than one.

.. warning::
:class:`MultiSolverInterface` uses manual optimization by setting
``automatic_optimization=False`` in
:class:`~lightning.pytorch.core.LightningModule`. For more
information on manual optimization please
see `here <https://lightning.ai/docs/pytorch/stable/\
model/manual_optimization.html>`_.
"""
if not isinstance(models, (list, tuple)) or len(models) < 2:
raise ValueError(
Expand All @@ -450,6 +462,16 @@ def __init__(
"one."
)

if optimizers is None:
optimizers = [
self.default_torch_optimizer() for _ in range(len(models))
]

if schedulers is None:
schedulers = [
self.default_torch_scheduler() for _ in range(len(models))
]

if any(opt is None for opt in optimizers):
optimizers = [
self.default_torch_optimizer() if opt is None else opt
Expand Down Expand Up @@ -480,12 +502,23 @@ def __init__(
f"Got {len(models)} models, and {len(optimizers)}"
" optimizers."
)
if len(schedulers) != len(optimizers):
raise ValueError(
"You must define one scheduler for each optimizer."
f"Got {len(schedulers)} schedulers, and {len(optimizers)}"
" optimizers."
)

# initialize the model
self._pina_models = torch.nn.ModuleList(models)
self._pina_optimizers = optimizers
self._pina_schedulers = schedulers

# Set automatic optimization to False.
# For more information on manual optimization see:
# http://lightning.ai/docs/pytorch/stable/model/manual_optimization.html
self.automatic_optimization = False

def configure_optimizers(self):
"""
Optimizer configuration for the solver.
Expand Down
Loading