1414
1515class SolverInterface (lightning .pytorch .LightningModule , metaclass = ABCMeta ):
1616 """
17- Abstract base class for PINA solvers. All specific solvers should inherit
18- from this interface. This class is a wrapper of
19- :class:`~lightning.pytorch.LightningModule`.
17+ Abstract base class for PINA solvers. All specific solvers must inherit
18+ from this interface. This class extends
19+ :class:`~lightning.pytorch.core.LightningModule`, providing additional
20+ functionalities for defining and optimizing Deep Learning models.
21+
22+ By inheriting from this base class, solvers gain access to built-in training
23+ loops, logging utilities, and optimization techniques.
2024 """
2125
2226 def __init__ (self , problem , weighting , use_lt ):
@@ -442,6 +446,14 @@ def __init__(
442446 :param bool use_lt: If ``True``, the solver uses LabelTensors as input.
443447 :raises ValueError: If the models are not a list or tuple with length
444448 greater than one.
449+
450+ .. warning::
451+ :class:`MultiSolverInterface` uses manual optimization by setting
452+ ``automatic_optimization=False`` in
453+ :class:`~lightning.pytorch.core.LightningModule`. For more
454+ information on manual optimization please
455+ see `here <https://lightning.ai/docs/pytorch/stable/\
456+ model/manual_optimization.html>`_.
445457 """
446458 if not isinstance (models , (list , tuple )) or len (models ) < 2 :
447459 raise ValueError (
@@ -450,6 +462,16 @@ def __init__(
450462 "one."
451463 )
452464
465+ if optimizers is None :
466+ optimizers = [
467+ self .default_torch_optimizer () for _ in range (len (models ))
468+ ]
469+
470+ if schedulers is None :
471+ schedulers = [
472+ self .default_torch_scheduler () for _ in range (len (models ))
473+ ]
474+
453475 if any (opt is None for opt in optimizers ):
454476 optimizers = [
455477 self .default_torch_optimizer () if opt is None else opt
@@ -480,12 +502,23 @@ def __init__(
480502 f"Got { len (models )} models, and { len (optimizers )} "
481503 " optimizers."
482504 )
505+ if len (schedulers ) != len (optimizers ):
506+ raise ValueError (
507+ "You must define one scheduler for each optimizer."
508+ f"Got { len (schedulers )} schedulers, and { len (optimizers )} "
509+ " optimizers."
510+ )
483511
484512 # initialize the model
485513 self ._pina_models = torch .nn .ModuleList (models )
486514 self ._pina_optimizers = optimizers
487515 self ._pina_schedulers = schedulers
488516
517+ # Set automatic optimization to False.
518+ # For more information on manual optimization see:
519+ # http://lightning.ai/docs/pytorch/stable/model/manual_optimization.html
520+ self .automatic_optimization = False
521+
489522 def configure_optimizers (self ):
490523 """
491524 Optimizer configuration for the solver.
0 commit comments