@@ -119,21 +119,8 @@ def configure_schedulers(
119119 lr_schedulers = []
120120 default_config = _get_default_scheduler_config ()
121121 for scheduler in schedulers :
122- if isinstance (scheduler , dict ):
123- # check provided keys
124- extra_keys = [k for k in scheduler .keys () if k not in default_config .keys ()]
125- if extra_keys :
126- rank_zero_warn (f"Found unsupported keys in the lr scheduler dict: { extra_keys } " , RuntimeWarning )
127- if "scheduler" not in scheduler :
128- raise MisconfigurationException (
129- 'The lr scheduler dict must have the key "scheduler" with its item being an lr scheduler'
130- )
131- if "interval" in scheduler and scheduler ["interval" ] not in ("step" , "epoch" ):
132- raise MisconfigurationException (
133- f'The "interval" key in lr scheduler dict must be "step" or "epoch"'
134- f' but is "{ scheduler ["interval" ]} "'
135- )
136- if is_manual_optimization :
122+ if is_manual_optimization :
123+ if isinstance (scheduler , dict ):
137124 invalid_keys = {"interval" , "frequency" , "reduce_on_plateau" , "monitor" , "strict" }
138125 keys_to_warn = [k for k in scheduler .keys () if k in invalid_keys ]
139126
@@ -144,30 +131,49 @@ def configure_schedulers(
144131 RuntimeWarning ,
145132 )
146133
147- scheduler ["reduce_on_plateau" ] = isinstance (
148- scheduler ["scheduler" ], optim .lr_scheduler .ReduceLROnPlateau
149- )
150- if scheduler ["reduce_on_plateau" ] and scheduler .get ("monitor" , None ) is None :
151- raise MisconfigurationException (
152- "The lr scheduler dict must include a monitor when a `ReduceLROnPlateau` scheduler is used."
153- ' For example: {"optimizer": optimizer, "lr_scheduler":'
154- ' {"scheduler": scheduler, "monitor": "your_loss"}}'
134+ scheduler = {key : scheduler [key ] for key in scheduler if key not in invalid_keys }
135+ lr_schedulers .append ({** default_config , ** scheduler })
136+ else :
137+ lr_schedulers .append ({** default_config , "scheduler" : scheduler })
138+ else :
139+ if isinstance (scheduler , dict ):
140+ # check provided keys
141+ extra_keys = [k for k in scheduler .keys () if k not in default_config .keys ()]
142+ if extra_keys :
143+ rank_zero_warn (f"Found unsupported keys in the lr scheduler dict: { extra_keys } " , RuntimeWarning )
144+ if "scheduler" not in scheduler :
145+ raise MisconfigurationException (
146+ 'The lr scheduler dict must have the key "scheduler" with its item being an lr scheduler'
147+ )
148+ if "interval" in scheduler and scheduler ["interval" ] not in ("step" , "epoch" ):
149+ raise MisconfigurationException (
150+ 'The "interval" key in lr scheduler dict must be "step" or "epoch"'
151+ f' but is "{ scheduler ["interval" ]} "'
152+ )
153+ scheduler ["reduce_on_plateau" ] = isinstance (
154+ scheduler ["scheduler" ], optim .lr_scheduler .ReduceLROnPlateau
155155 )
156- lr_schedulers .append ({** default_config , ** scheduler })
157- elif isinstance (scheduler , optim .lr_scheduler .ReduceLROnPlateau ):
158- if monitor is None :
159- raise MisconfigurationException (
160- "`configure_optimizers` must include a monitor when a `ReduceLROnPlateau` scheduler is used."
161- " For example:"
162- ' {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "metric_to_track"}'
156+ if scheduler ["reduce_on_plateau" ] and scheduler .get ("monitor" , None ) is None :
157+ raise MisconfigurationException (
158+ "The lr scheduler dict must include a monitor when a `ReduceLROnPlateau` scheduler is used."
159+ ' For example: {"optimizer": optimizer, "lr_scheduler":'
160+ ' {"scheduler": scheduler, "monitor": "your_loss"}}'
161+ )
162+ lr_schedulers .append ({** default_config , ** scheduler })
163+ elif isinstance (scheduler , optim .lr_scheduler .ReduceLROnPlateau ):
164+ if monitor is None :
165+ raise MisconfigurationException (
166+ "`configure_optimizers` must include a monitor when a `ReduceLROnPlateau`"
167+ " scheduler is used. For example:"
168+ ' {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "metric_to_track"}'
169+ )
170+ lr_schedulers .append (
171+ {** default_config , "scheduler" : scheduler , "reduce_on_plateau" : True , "monitor" : monitor }
163172 )
164- lr_schedulers .append (
165- {** default_config , "scheduler" : scheduler , "reduce_on_plateau" : True , "monitor" : monitor }
166- )
167- elif isinstance (scheduler , optim .lr_scheduler ._LRScheduler ):
168- lr_schedulers .append ({** default_config , "scheduler" : scheduler })
169- else :
170- raise ValueError (f'The provided lr scheduler "{ scheduler } " is invalid' )
173+ elif isinstance (scheduler , optim .lr_scheduler ._LRScheduler ):
174+ lr_schedulers .append ({** default_config , "scheduler" : scheduler })
175+ else :
176+ raise ValueError (f'The provided lr scheduler "{ scheduler } " is invalid' )
171177 return lr_schedulers
172178
173179
0 commit comments