diff --git a/pina/solver/physics_informed_solver/competitive_pinn.py b/pina/solver/physics_informed_solver/competitive_pinn.py index 058c53f40..18a9861e9 100644 --- a/pina/solver/physics_informed_solver/competitive_pinn.py +++ b/pina/solver/physics_informed_solver/competitive_pinn.py @@ -103,9 +103,6 @@ def __init__( loss=loss, ) - # Set automatic optimization to False - self.automatic_optimization = False - def forward(self, x): """ Forward pass. diff --git a/pina/solver/physics_informed_solver/self_adaptive_pinn.py b/pina/solver/physics_informed_solver/self_adaptive_pinn.py index a6310d515..d9c9a9bb3 100644 --- a/pina/solver/physics_informed_solver/self_adaptive_pinn.py +++ b/pina/solver/physics_informed_solver/self_adaptive_pinn.py @@ -158,9 +158,6 @@ def __init__( loss=loss, ) - # Set automatic optimization to False - self.automatic_optimization = False - self._vectorial_loss = deepcopy(self.loss) self._vectorial_loss.reduction = "none" diff --git a/pina/solver/solver.py b/pina/solver/solver.py index 2a173b33d..6776fea9d 100644 --- a/pina/solver/solver.py +++ b/pina/solver/solver.py @@ -14,9 +14,13 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta): """ - Abstract base class for PINA solvers. All specific solvers should inherit - from this interface. This class is a wrapper of - :class:`~lightning.pytorch.LightningModule`. + Abstract base class for PINA solvers. All specific solvers must inherit + from this interface. This class extends + :class:`~lightning.pytorch.core.LightningModule`, providing additional + functionalities for defining and optimizing Deep Learning models. + + By inheriting from this base class, solvers gain access to built-in training + loops, logging utilities, and optimization techniques. """ def __init__(self, problem, weighting, use_lt): @@ -442,6 +446,14 @@ def __init__( :param bool use_lt: If ``True``, the solver uses LabelTensors as input. :raises ValueError: If the models are not a list or tuple with length greater than one. + + .. warning:: + :class:`MultiSolverInterface` uses manual optimization by setting + ``automatic_optimization=False`` in + :class:`~lightning.pytorch.core.LightningModule`. For more + information on manual optimization please + see `here `_. """ if not isinstance(models, (list, tuple)) or len(models) < 2: raise ValueError( @@ -450,6 +462,16 @@ def __init__( "one." ) + if optimizers is None: + optimizers = [ + self.default_torch_optimizer() for _ in range(len(models)) + ] + + if schedulers is None: + schedulers = [ + self.default_torch_scheduler() for _ in range(len(models)) + ] + if any(opt is None for opt in optimizers): optimizers = [ self.default_torch_optimizer() if opt is None else opt @@ -480,12 +502,23 @@ def __init__( f"Got {len(models)} models, and {len(optimizers)}" " optimizers." ) + if len(schedulers) != len(optimizers): + raise ValueError( + "You must define one scheduler for each optimizer." + f"Got {len(schedulers)} schedulers, and {len(optimizers)}" + " optimizers." + ) # initialize the model self._pina_models = torch.nn.ModuleList(models) self._pina_optimizers = optimizers self._pina_schedulers = schedulers + # Set automatic optimization to False. + # For more information on manual optimization see: + # http://lightning.ai/docs/pytorch/stable/model/manual_optimization.html + self.automatic_optimization = False + def configure_optimizers(self): """ Optimizer configuration for the solver.