Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 985304e

Browse files
authored
LearningRateFunctionModifier Implementation for PyTorch (#288)
* Add in initial modifier for lr function in PyTorch * add param_groups awareness to set lr modifier * bug fixes * add in unit tests and bug fixes for lr function modifier * style fix for lr modifiers diff * fix broken test for set_optim_learning_rate * remove test recipe yaml for lr function modifier * update docs for param_groups in set lr modifier Co-authored-by: Mark Kurtz <[email protected]>
1 parent 70f93b7 commit 985304e

File tree

3 files changed

+536
-17
lines changed

3 files changed

+536
-17
lines changed

src/sparseml/pytorch/optim/modifier_lr.py

Lines changed: 296 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
import math
2121
import sys
22-
from typing import Dict, List, Union
22+
from typing import Dict, List, Optional, Tuple, Union
2323

2424
from torch.nn import Module
2525
from torch.optim.lr_scheduler import (
@@ -39,13 +39,17 @@
3939
)
4040
from sparseml.pytorch.utils import (
4141
BaseLogger,
42-
get_optim_learning_rate,
42+
get_optim_groups_learning_rates,
4343
set_optim_learning_rate,
4444
)
4545
from sparseml.utils import ALL_TOKEN, convert_to_bool
4646

4747

48-
__all__ = ["SetLearningRateModifier", "LearningRateModifier"]
48+
__all__ = [
49+
"SetLearningRateModifier",
50+
"LearningRateFunctionModifier",
51+
"LearningRateModifier",
52+
]
4953

5054

5155
CONSTRUCTORS = {
@@ -57,12 +61,16 @@
5761

5862

5963
def _log_lr(
    group_lrs: List[Tuple[str, float]],
    loggers: List[BaseLogger],
    epoch: float,
    steps_per_epoch: int,
):
    """
    Log the learning rate of each tracked param group to every logger.

    :param group_lrs: (group_name, lr) pairs for the param groups to record
    :param loggers: the loggers to record the learning rates to
    :param epoch: current epoch and progress within the current epoch
    :param steps_per_epoch: number of steps taken within each epoch;
        when <= 0 the rounded epoch is used as the step instead of a batch index
    """
    if steps_per_epoch > 0:
        step = round(epoch * steps_per_epoch)
    else:
        step = round(epoch)

    for log in loggers:
        for name, lr in group_lrs:
            log.log_scalar(f"LearningRateModifier/{name}", lr, step)
6674

6775

6876
@PyTorchModifierYAML()
@@ -93,6 +101,7 @@ class SetLearningRateModifier(ScheduledModifier, SetLearningRate):
93101
def __init__(
94102
self,
95103
learning_rate: Union[float, None],
104+
param_groups: Optional[List[int]] = None,
96105
start_epoch: float = -1.0,
97106
end_epoch: float = -1.0,
98107
log_types: Union[str, List[str]] = ALL_TOKEN,
@@ -105,12 +114,29 @@ def __init__(
105114
end_epoch=-1,
106115
end_comparator=None,
107116
)
117+
self._param_groups = param_groups
108118
self._lr_set = False
109119
self._applied = -1.0
110120
self._constant_logging = convert_to_bool(constant_logging)
111121
self._last_logged_lr = None
112122
self._last_logged_epoch = None
113123

124+
@ModifierProp()
125+
def param_groups(self) -> Optional[List[int]]:
126+
"""
127+
:return: The param group indices to set the lr for within the optimizer,
128+
if not set will set the lr for all param groups
129+
"""
130+
return self._param_groups
131+
132+
@param_groups.setter
133+
def param_groups(self, value: Optional[List[int]]):
134+
"""
135+
:param value: The param group indices to set the lr for within the optimizer,
136+
if not set will set the lr for all param groups
137+
"""
138+
self._param_groups = value
139+
114140
@ModifierProp()
115141
def constant_logging(self) -> bool:
116142
"""
@@ -165,16 +191,28 @@ def log_update(
165191
(calculate batch number using this and epoch)
166192
"""
167193
super().log_update(module, optimizer, epoch, steps_per_epoch)
168-
current_lr = get_optim_learning_rate(optimizer)
194+
group_lrs = [
195+
(f"ParamGroup{index}", lr)
196+
for (index, lr) in enumerate(get_optim_groups_learning_rates(optimizer))
197+
if not self.param_groups or index in self.param_groups
198+
]
199+
200+
if not group_lrs:
201+
raise ValueError(
202+
"Could not find param groups in the optimizer "
203+
f"for given param_groups {self.param_groups}"
204+
)
205+
206+
current_lr = group_lrs[-1][1]
169207

170208
if (
171209
self._constant_logging
172-
or current_lr != self._last_logged_lr
210+
or self._last_logged_lr != current_lr
173211
or math.floor(epoch) != self._last_logged_epoch
174212
):
175213
self._last_logged_lr = current_lr
176214
self._last_logged_epoch = math.floor(epoch)
177-
_log_lr(current_lr, self.loggers, epoch, steps_per_epoch)
215+
_log_lr(group_lrs, self.loggers, epoch, steps_per_epoch)
178216

179217
def _check_set_lr(self, optimizer: Optimizer, epoch: float):
180218
if (
@@ -185,11 +223,249 @@ def _check_set_lr(self, optimizer: Optimizer, epoch: float):
185223
and not self._lr_set
186224
and self._learning_rate is not None
187225
):
188-
set_optim_learning_rate(optimizer, self.learning_rate)
189-
self._applied = self._learning_rate
226+
for (index, group) in enumerate(optimizer.param_groups):
227+
if not self.param_groups or index in self.param_groups:
228+
group["lr"] = self.learning_rate
229+
self._applied = self.learning_rate
190230
self._lr_set = True
191231

192232

233+
@PyTorchModifierYAML()
class LearningRateFunctionModifier(ScheduledUpdateModifier):
    """
    Modifier to set the learning rate based on supported math functions scaling between
    an init_lr and a final_lr.
    Any time an update point is reached, the LR is updated for the parameters groups
    in the optimizer.
    Specific parameter groups can be targeted for the optimizer as well.

    | Sample yaml:
    |    !LearningRateFunctionModifier
    |        start_epoch: 0.0
    |        end_epoch: 10.0
    |        lr_func: linear
    |        init_lr: 0.1
    |        final_lr: 0.001

    :param lr_func: The name of the lr function to use: [linear, cosine]
    :param init_lr: The initial learning rate to use once this modifier starts
    :param final_lr: The final learning rate to scale towards by end_epoch
    :param start_epoch: The epoch to start the modifier at
        (set to -1.0 so it starts immediately)
    :param end_epoch: The epoch to end the modifier at,
        (set to -1.0 so it doesn't end)
    :param param_groups: The param group indices to set the lr for within the optimizer,
        if not set will set the lr for all param groups
    :param update_frequency: unused and should not be set
    :param log_types: The loggers to allow the learning rate to be logged to,
        default is __ALL__
    """

    def __init__(
        self,
        lr_func: str,
        init_lr: float,
        final_lr: float,
        start_epoch: float,
        end_epoch: float,
        param_groups: Optional[List[int]] = None,
        update_frequency: float = -1.0,
        log_types: Union[str, List[str]] = ALL_TOKEN,
    ):
        super().__init__(
            log_types=log_types,
            start_epoch=start_epoch,
            end_epoch=end_epoch,
            # update_frequency is intentionally pinned to -1.0 (update every step);
            # validate() rejects any other value
            update_frequency=-1.0,
            end_comparator=1,
        )
        self._lr_func = lr_func
        self._init_lr = init_lr
        self._final_lr = final_lr
        self._param_groups = param_groups
        self._learning_rate = None  # last LR computed by update()
        self._last_applied_lr = None  # NOTE(review): assigned but unused in this class
        self._last_logged_lr = None  # dedup state for log_update()
        self._last_logged_epoch = None
        self.validate()

    @ModifierProp()
    def lr_func(self) -> str:
        """
        :return: The name of the lr function to use: [linear, cosine]
        """
        return self._lr_func

    @lr_func.setter
    def lr_func(self, value: str):
        """
        :param value: The name of the lr function to use: [linear, cosine]
        """
        self._lr_func = value
        self.validate()

    @ModifierProp()
    def init_lr(self) -> float:
        """
        :return: The initial learning rate to use once this modifier starts
        """
        return self._init_lr

    @init_lr.setter
    def init_lr(self, value: float):
        """
        :param value: The initial learning rate to use once this modifier starts
        """
        self._init_lr = value
        self.validate()

    @ModifierProp()
    def final_lr(self) -> float:
        """
        :return: The final learning rate to scale towards by end_epoch
        """
        return self._final_lr

    @final_lr.setter
    def final_lr(self, value: float):
        """
        :param value: The final learning rate to scale towards by end_epoch
        """
        self._final_lr = value
        self.validate()

    @ModifierProp()
    def param_groups(self) -> Optional[List[int]]:
        """
        :return: The param group indices to set the lr for within the optimizer,
            if not set will set the lr for all param groups
        """
        return self._param_groups

    @param_groups.setter
    def param_groups(self, value: Optional[List[int]]):
        """
        :param value: The param group indices to set the lr for within the optimizer,
            if not set will set the lr for all param groups
        """
        self._param_groups = value
        self.validate()

    def update(
        self, module: Module, optimizer: Optimizer, epoch: float, steps_per_epoch: int
    ):
        """
        Updates the LR based on the given epoch for the optimizer

        :param module: module to modify
        :param optimizer: optimizer to modify
        :param epoch: current epoch and progress within the current epoch
        :param steps_per_epoch: number of steps taken within each epoch
            (calculate batch number using this and epoch)
        """
        super().update(module, optimizer, epoch, steps_per_epoch)
        # dispatch to _linear / _cosine by name based on the configured lr_func
        lr_function = getattr(LearningRateFunctionModifier, f"_{self._lr_func}")
        self._learning_rate = lr_function(self, epoch)
        set_optim_learning_rate(optimizer, self._learning_rate, self.param_groups)

    def log_update(
        self, module: Module, optimizer: Optimizer, epoch: float, steps_per_epoch: int
    ):
        """
        Check whether to log an update for the learning rate of the modifier.
        Checks for a change in the LR or epoch before logging

        :param module: module to modify
        :param optimizer: optimizer to modify
        :param epoch: current epoch and progress within the current epoch
        :param steps_per_epoch: number of steps taken within each epoch
            (calculate batch number using this and epoch)
        """
        super().log_update(module, optimizer, epoch, steps_per_epoch)
        # collect (name, lr) pairs for the targeted param groups only
        group_lrs = [
            (f"ParamGroup{index}", lr)
            for (index, lr) in enumerate(get_optim_groups_learning_rates(optimizer))
            if not self.param_groups or index in self.param_groups
        ]

        if not group_lrs:
            raise ValueError(
                "Could not find param groups in the optimizer "
                f"for given param_groups {self.param_groups}"
            )

        current_lr = group_lrs[-1][1]

        # only log on an LR change or at least once per epoch
        if (
            current_lr != self._last_logged_lr
            or math.floor(epoch) != self._last_logged_epoch
        ):
            _log_lr(group_lrs, self.loggers, epoch, steps_per_epoch)
            self._last_logged_lr = current_lr
            self._last_logged_epoch = math.floor(epoch)

    def validate(self):
        """
        Validate the values of the params for the current instance are valid

        :raises ValueError: if lr_func is unsupported, init_lr or final_lr fall
            outside [0.0, 1.0], or update_frequency was overridden from -1.0
        """
        lr_funcs = ["linear", "cosine"]
        if self.lr_func not in lr_funcs:
            raise ValueError(f"lr_func must be one of {lr_funcs}")

        if (
            (not self.init_lr and self.init_lr != 0)
            or self.init_lr < 0.0
            or self.init_lr > 1.0
        ):
            raise ValueError(
                f"init_lr must be within range [0.0, 1.0], given {self.init_lr}"
            )

        if (
            (not self.final_lr and self.final_lr != 0)
            or self.final_lr < 0.0
            or self.final_lr > 1.0
        ):
            raise ValueError(
                f"final_lr must be within range [0.0, 1.0], given {self.final_lr}"
            )

        if self.update_frequency != -1.0:
            raise ValueError("update_frequency must be kept at -1.0")

    def _linear(self, epoch: float) -> float:
        # linear interpolation: y = y1 + ((x - x1) / (x2 - x1)) * (y2 - y1)
        start = self.start_epoch if self.start_epoch > 0 else 0.0
        end = self.end_epoch

        return self.init_lr + ((epoch - start) / (end - start)) * (
            self.final_lr - self.init_lr
        )

    def _cosine(self, epoch: float) -> float:
        start = self.start_epoch if self.start_epoch > 0 else 0.0
        end = self.end_epoch

        # scale x to [0-1] for use with cosine
        x_norm = (epoch - start) / (end - start)

        # conditional to support cosine down to a value and up to a value
        if self.final_lr < self.init_lr:
            y_range = self.init_lr - self.final_lr
            y_shift = self.final_lr
            x_shift = 0
        else:
            y_range = self.final_lr - self.init_lr
            y_shift = self.init_lr
            x_shift = math.pi

        return (
            math.cos(x_norm * math.pi + x_shift) * y_range / 2 + y_range / 2 + y_shift
        )
467+
468+
193469
@PyTorchModifierYAML()
194470
class LearningRateModifier(ScheduledUpdateModifier, LearningRate):
195471
"""
@@ -337,7 +613,15 @@ def log_update(
337613
(calculate batch number using this and epoch)
338614
"""
339615
super().log_update(module, optimizer, epoch, steps_per_epoch)
340-
current_lr = get_optim_learning_rate(optimizer)
616+
group_lrs = [
617+
(f"ParamGroup{index}", lr)
618+
for (index, lr) in enumerate(get_optim_groups_learning_rates(optimizer))
619+
]
620+
621+
if not group_lrs:
622+
raise ValueError("Could not find any param groups in the optimizer")
623+
624+
current_lr = group_lrs[-1][1]
341625

342626
if (
343627
self._constant_logging
@@ -346,7 +630,7 @@ def log_update(
346630
):
347631
self._last_logged_lr = current_lr
348632
self._last_logged_epoch = math.floor(epoch)
349-
_log_lr(current_lr, self.loggers, epoch, steps_per_epoch)
633+
_log_lr(group_lrs, self.loggers, epoch, steps_per_epoch)
350634

351635
def validate(self):
352636
"""

0 commit comments

Comments
 (0)