1919
2020import math
2121import sys
22- from typing import Dict , List , Union
22+ from typing import Dict , List , Optional , Tuple , Union
2323
2424from torch .nn import Module
2525from torch .optim .lr_scheduler import (
3939)
4040from sparseml .pytorch .utils import (
4141 BaseLogger ,
42- get_optim_learning_rate ,
42+ get_optim_groups_learning_rates ,
4343 set_optim_learning_rate ,
4444)
4545from sparseml .utils import ALL_TOKEN , convert_to_bool
4646
4747
48- __all__ = ["SetLearningRateModifier" , "LearningRateModifier" ]
48+ __all__ = [
49+ "SetLearningRateModifier" ,
50+ "LearningRateFunctionModifier" ,
51+ "LearningRateModifier" ,
52+ ]
4953
5054
5155CONSTRUCTORS = {
5761
5862
def _log_lr(
    group_lrs: List[Tuple[str, float]],
    loggers: List[BaseLogger],
    epoch: float,
    steps_per_epoch: int,
):
    """
    Log the current learning rate of each param group to all of the given loggers.

    :param group_lrs: (group_name, learning_rate) pairs, one per optimizer
        param group that should be logged
    :param loggers: the loggers to record the learning rates to
    :param epoch: current epoch and progress within the current epoch
    :param steps_per_epoch: number of steps taken within each epoch;
        <= 0 logs at epoch granularity instead of step granularity
    """
    # convert the fractional epoch into an integer step for the logger x-axis
    step = round(epoch) if steps_per_epoch <= 0 else round(epoch * steps_per_epoch)

    for logger in loggers:
        for (group_name, group_lr) in group_lrs:
            logger.log_scalar(f"LearningRateModifier/{group_name}", group_lr, step)
6674
6775
6876@PyTorchModifierYAML ()
@@ -93,6 +101,7 @@ class SetLearningRateModifier(ScheduledModifier, SetLearningRate):
93101 def __init__ (
94102 self ,
95103 learning_rate : Union [float , None ],
104+ param_groups : Optional [List [int ]] = None ,
96105 start_epoch : float = - 1.0 ,
97106 end_epoch : float = - 1.0 ,
98107 log_types : Union [str , List [str ]] = ALL_TOKEN ,
@@ -105,12 +114,29 @@ def __init__(
105114 end_epoch = - 1 ,
106115 end_comparator = None ,
107116 )
117+ self ._param_groups = param_groups
108118 self ._lr_set = False
109119 self ._applied = - 1.0
110120 self ._constant_logging = convert_to_bool (constant_logging )
111121 self ._last_logged_lr = None
112122 self ._last_logged_epoch = None
113123
124+ @ModifierProp ()
125+ def param_groups (self ) -> Optional [List [int ]]:
126+ """
127+ :return: The param group indices to set the lr for within the optimizer,
128+ if not set will set the lr for all param groups
129+ """
130+ return self ._param_groups
131+
132+ @param_groups .setter
133+ def param_groups (self , value : Optional [List [int ]]):
134+ """
135+ :param value: The param group indices to set the lr for within the optimizer,
136+ if not set will set the lr for all param groups
137+ """
138+ self ._param_groups = value
139+
114140 @ModifierProp ()
115141 def constant_logging (self ) -> bool :
116142 """
@@ -165,16 +191,28 @@ def log_update(
165191 (calculate batch number using this and epoch)
166192 """
167193 super ().log_update (module , optimizer , epoch , steps_per_epoch )
168- current_lr = get_optim_learning_rate (optimizer )
194+ group_lrs = [
195+ (f"ParamGroup{ index } " , lr )
196+ for (index , lr ) in enumerate (get_optim_groups_learning_rates (optimizer ))
197+ if not self .param_groups or index in self .param_groups
198+ ]
199+
200+ if not group_lrs :
201+ raise ValueError (
202+ "Could not find param groups in the optimizer "
203+ f"for given param_groups { self .param_groups } "
204+ )
205+
206+ current_lr = group_lrs [- 1 ][1 ]
169207
170208 if (
171209 self ._constant_logging
172- or current_lr != self . _last_logged_lr
210+ or self . _last_logged_lr != current_lr
173211 or math .floor (epoch ) != self ._last_logged_epoch
174212 ):
175213 self ._last_logged_lr = current_lr
176214 self ._last_logged_epoch = math .floor (epoch )
177- _log_lr (current_lr , self .loggers , epoch , steps_per_epoch )
215+ _log_lr (group_lrs , self .loggers , epoch , steps_per_epoch )
178216
179217 def _check_set_lr (self , optimizer : Optimizer , epoch : float ):
180218 if (
@@ -185,11 +223,249 @@ def _check_set_lr(self, optimizer: Optimizer, epoch: float):
185223 and not self ._lr_set
186224 and self ._learning_rate is not None
187225 ):
188- set_optim_learning_rate (optimizer , self .learning_rate )
189- self ._applied = self ._learning_rate
226+ for (index , group ) in enumerate (optimizer .param_groups ):
227+ if not self .param_groups or index in self .param_groups :
228+ group ["lr" ] = self .learning_rate
229+ self ._applied = self .learning_rate
190230 self ._lr_set = True
191231
192232
@PyTorchModifierYAML()
class LearningRateFunctionModifier(ScheduledUpdateModifier):
    """
    Modifier to set the learning rate based on supported math functions scaling
    between an init_lr and a final_lr.
    Any time an update point is reached, the LR is updated for the parameter groups
    in the optimizer.
    Specific parameter groups can be targeted for the optimizer as well.

    | Sample yaml:
    |   !LearningRateFunctionModifier
    |       start_epoch: 0.0
    |       end_epoch: 10.0
    |       lr_func: linear
    |       init_lr: 0.1
    |       final_lr: 0.001

    :param lr_func: The name of the lr function to use: [linear, cosine]
    :param init_lr: The initial learning rate to use once this modifier starts
    :param final_lr: The final learning rate to reach by the time this modifier ends
    :param start_epoch: The epoch to start the modifier at
        (set to -1.0 so it starts immediately)
    :param end_epoch: The epoch to end the modifier at,
        (set to -1.0 so it doesn't end)
    :param param_groups: The param group indices to set the lr for within the
        optimizer, if not set will set the lr for all param groups
    :param update_frequency: unused and should not be set
    :param log_types: The loggers to allow the learning rate to be logged to,
        default is __ALL__
    """

    def __init__(
        self,
        lr_func: str,
        init_lr: float,
        final_lr: float,
        start_epoch: float,
        end_epoch: float,
        param_groups: Optional[List[int]] = None,
        update_frequency: float = -1.0,
        log_types: Union[str, List[str]] = ALL_TOKEN,
    ):
        super().__init__(
            log_types=log_types,
            start_epoch=start_epoch,
            end_epoch=end_epoch,
            update_frequency=-1.0,
            end_comparator=1,
        )
        self._lr_func = lr_func
        self._init_lr = init_lr
        self._final_lr = final_lr
        self._param_groups = param_groups
        self._learning_rate = None
        self._last_applied_lr = None
        self._last_logged_lr = None
        self._last_logged_epoch = None
        self.validate()

    @ModifierProp()
    def lr_func(self) -> str:
        """
        :return: The name of the lr function to use: [linear, cosine]
        """
        return self._lr_func

    @lr_func.setter
    def lr_func(self, value: str):
        """
        :param value: The name of the lr function to use: [linear, cosine]
        """
        self._lr_func = value
        self.validate()

    @ModifierProp()
    def init_lr(self) -> float:
        """
        :return: The initial learning rate to use once this modifier starts
        """
        return self._init_lr

    @init_lr.setter
    def init_lr(self, value: float):
        """
        :param value: The initial learning rate to use once this modifier starts
        """
        self._init_lr = value
        self.validate()

    @ModifierProp()
    def final_lr(self) -> float:
        """
        :return: The final learning rate to reach by the time this modifier ends
        """
        return self._final_lr

    @final_lr.setter
    def final_lr(self, value: float):
        """
        :param value: The final learning rate to reach by the time this modifier ends
        """
        self._final_lr = value
        self.validate()

    @ModifierProp()
    def param_groups(self) -> Optional[List[int]]:
        """
        :return: The param group indices to set the lr for within the optimizer,
            if not set will set the lr for all param groups
        """
        return self._param_groups

    @param_groups.setter
    def param_groups(self, value: Optional[List[int]]):
        """
        :param value: The param group indices to set the lr for within the optimizer,
            if not set will set the lr for all param groups
        """
        self._param_groups = value
        self.validate()

    def update(
        self, module: Module, optimizer: Optimizer, epoch: float, steps_per_epoch: int
    ):
        """
        Updates the LR based on the given epoch for the optimizer

        :param module: module to modify
        :param optimizer: optimizer to modify
        :param epoch: current epoch and progress within the current epoch
        :param steps_per_epoch: number of steps taken within each epoch
            (calculate batch number using this and epoch)
        """
        super().update(module, optimizer, epoch, steps_per_epoch)
        # dispatch to _linear or _cosine based on the configured lr_func name
        lr_function = getattr(LearningRateFunctionModifier, f"_{self._lr_func}")
        self._learning_rate = lr_function(self, epoch)
        set_optim_learning_rate(optimizer, self._learning_rate, self.param_groups)

    def log_update(
        self, module: Module, optimizer: Optimizer, epoch: float, steps_per_epoch: int
    ):
        """
        Check whether to log an update for the learning rate of the modifier.
        Checks for a change in the LR or epoch before logging

        :param module: module to modify
        :param optimizer: optimizer to modify
        :param epoch: current epoch and progress within the current epoch
        :param steps_per_epoch: number of steps taken within each epoch
            (calculate batch number using this and epoch)
        """
        super().log_update(module, optimizer, epoch, steps_per_epoch)
        # gather lrs only for the param groups this modifier targets;
        # an unset param_groups means every group is included
        group_lrs = [
            (f"ParamGroup{index}", lr)
            for (index, lr) in enumerate(get_optim_groups_learning_rates(optimizer))
            if not self.param_groups or index in self.param_groups
        ]

        if not group_lrs:
            raise ValueError(
                "Could not find param groups in the optimizer "
                f"for given param_groups {self.param_groups}"
            )

        current_lr = group_lrs[-1][1]

        if (
            current_lr != self._last_logged_lr
            or math.floor(epoch) != self._last_logged_epoch
        ):
            _log_lr(group_lrs, self.loggers, epoch, steps_per_epoch)
            self._last_logged_lr = current_lr
            self._last_logged_epoch = math.floor(epoch)

    def validate(self):
        """
        Validate the values of the params for the current instance are valid
        """
        lr_funcs = ["linear", "cosine"]
        if self.lr_func not in lr_funcs:
            raise ValueError(f"lr_func must be one of {lr_funcs}")

        if (
            (not self.init_lr and self.init_lr != 0)
            or self.init_lr < 0.0
            or self.init_lr > 1.0
        ):
            raise ValueError(
                f"init_lr must be within range [0.0, 1.0], given {self.init_lr}"
            )

        if (
            (not self.final_lr and self.final_lr != 0)
            or self.final_lr < 0.0
            or self.final_lr > 1.0
        ):
            raise ValueError(
                f"final_lr must be within range [0.0, 1.0], given {self.final_lr}"
            )

        if self.update_frequency != -1.0:
            raise ValueError("update_frequency must be kept at -1.0")

    def _linear(self, epoch: float) -> float:
        # linear interpolation: y = y1 + ((x - x1) / (x2 - x1)) * (y2 - y1)
        start = self.start_epoch if self.start_epoch > 0 else 0.0
        end = self.end_epoch

        return self.init_lr + ((epoch - start) / (end - start)) * (
            self.final_lr - self.init_lr
        )

    def _cosine(self, epoch: float) -> float:
        start = self.start_epoch if self.start_epoch > 0 else 0.0
        end = self.end_epoch

        # scale x to [0-1] for use with cosine
        x_norm = (epoch - start) / (end - start)

        # conditional to support cosine down to a value and up to a value
        if self.final_lr < self.init_lr:
            y_range = self.init_lr - self.final_lr
            y_shift = self.final_lr
            x_shift = 0
        else:
            y_range = self.final_lr - self.init_lr
            y_shift = self.init_lr
            x_shift = math.pi

        return (
            math.cos(x_norm * math.pi + x_shift) * y_range / 2 + y_range / 2 + y_shift
        )
467+
468+
193469@PyTorchModifierYAML ()
194470class LearningRateModifier (ScheduledUpdateModifier , LearningRate ):
195471 """
@@ -337,7 +613,15 @@ def log_update(
337613 (calculate batch number using this and epoch)
338614 """
339615 super ().log_update (module , optimizer , epoch , steps_per_epoch )
340- current_lr = get_optim_learning_rate (optimizer )
616+ group_lrs = [
617+ (f"ParamGroup{ index } " , lr )
618+ for (index , lr ) in enumerate (get_optim_groups_learning_rates (optimizer ))
619+ ]
620+
621+ if not group_lrs :
622+ raise ValueError ("Could not find any param groups in the optimizer" )
623+
624+ current_lr = group_lrs [- 1 ][1 ]
341625
342626 if (
343627 self ._constant_logging
@@ -346,7 +630,7 @@ def log_update(
346630 ):
347631 self ._last_logged_lr = current_lr
348632 self ._last_logged_epoch = math .floor (epoch )
349- _log_lr (current_lr , self .loggers , epoch , steps_per_epoch )
633+ _log_lr (group_lrs , self .loggers , epoch , steps_per_epoch )
350634
351635 def validate (self ):
352636 """
0 commit comments