Skip to content

Commit c64cbd4

Browse files
authored
Update solver.py
1 parent 03ef90c commit c64cbd4

File tree

1 file changed

+23
-0
lines changed

1 file changed

+23
-0
lines changed

pina/solver/solver.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -450,6 +450,16 @@ def __init__(
450450
"one."
451451
)
452452

453+
if optimizers is None:
454+
optimizers = [
455+
self.default_torch_optimizer() for _ in range(len(models))
456+
]
457+
458+
if schedulers is None:
459+
schedulers = [
460+
self.default_torch_scheduler() for _ in range(len(models))
461+
]
462+
453463
if any(opt is None for opt in optimizers):
454464
optimizers = [
455465
self.default_torch_optimizer() if opt is None else opt
@@ -480,12 +490,25 @@ def __init__(
480490
f"Got {len(models)} models, and {len(optimizers)}"
481491
" optimizers."
482492
)
493+
if len(schedulers) != len(optimizers):
494+
raise ValueError(
495+
"You must define one scheduler for each optimizer."
496+
f"Got {len(schedulers)} schedulers, and {len(optimizers)}"
497+
" optimizers."
498+
)
483499

484500
# initialize the model
485501
self._pina_models = torch.nn.ModuleList(models)
486502
self._pina_optimizers = optimizers
487503
self._pina_schedulers = schedulers
488504

505+
# set automatic optimization to True, this is done on purpuse to trigger
506+
# an error if the user does not uses manual optimization in the
507+
# training step. The following must be override to False and manual
508+
# optimization should be used. For more insights on manual optimization
509+
# see https://lightning.ai/docs/pytorch/stable/model/manual_optimization.html
510+
self.automatic_optimization = True
511+
489512
def configure_optimizers(self):
490513
"""
491514
Optimizer configuration for the solver.

0 commit comments

Comments
 (0)