File tree Expand file tree Collapse file tree 1 file changed +23
-0
lines changed
Expand file tree Collapse file tree 1 file changed +23
-0
lines changed Original file line number Diff line number Diff line change @@ -450,6 +450,16 @@ def __init__(
450450 "one."
451451 )
452452
453+ if optimizers is None :
454+ optimizers = [
455+ self .default_torch_optimizer () for _ in range (len (models ))
456+ ]
457+
458+ if schedulers is None :
459+ schedulers = [
460+ self .default_torch_scheduler () for _ in range (len (models ))
461+ ]
462+
453463 if any (opt is None for opt in optimizers ):
454464 optimizers = [
455465 self .default_torch_optimizer () if opt is None else opt
@@ -480,12 +490,25 @@ def __init__(
480490 f"Got { len (models )} models, and { len (optimizers )} "
481491 " optimizers."
482492 )
493+ if len (schedulers ) != len (optimizers ):
494+ raise ValueError (
495+ "You must define one scheduler for each optimizer."
496+ f"Got { len (schedulers )} schedulers, and { len (optimizers )} "
497+ " optimizers."
498+ )
483499
484500 # initialize the model
485501 self ._pina_models = torch .nn .ModuleList (models )
486502 self ._pina_optimizers = optimizers
487503 self ._pina_schedulers = schedulers
488504
505+ # set automatic optimization to True, this is done on purpuse to trigger
506+ # an error if the user does not uses manual optimization in the
507+ # training step. The following must be override to False and manual
508+ # optimization should be used. For more insights on manual optimization
509+ # see https://lightning.ai/docs/pytorch/stable/model/manual_optimization.html
510+ self .automatic_optimization = True
511+
489512 def configure_optimizers (self ):
490513 """
491514 Optimizer configuration for the solver.
You can’t perform that action at this time.
0 commit comments