Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 93e0415

Browse files
authored
[cherry-pick] Cherry pick pytorch sched fixes to release 0.3 (#234)
* Fix bugs with ScheduledOptimizer and pytorch scheduler refactor (#233)
  * Fix bugs with ScheduledOptimizer and pytorch scheduler refactor
  * Fix test cases for updated log_update initialization
  * Revert changes to classification notebook
* Pytorch manager modify fixes (#235)
1 parent 733bef1 commit 93e0415

File tree

4 files changed

+14
-26
lines changed

4 files changed

+14
-26
lines changed

src/sparseml/pytorch/optim/manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -368,13 +368,13 @@ def modify(
368368
epoch = self._initialize_epoch
369369

370370
if not self.initialized:
371-
self.intialize(module, epoch)
371+
self.initialize(module, epoch)
372372

373373
if wrap_optim is None:
374374
wrap_optim = optimizer
375375

376376
return RecipeManagerStepWrapper(
377-
wrap_optim, self, optimizer, module, epoch, steps_per_epoch
377+
wrap_optim, optimizer, module, self, epoch, steps_per_epoch
378378
)
379379

380380
def finalize(

src/sparseml/pytorch/optim/modifier.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -196,9 +196,7 @@ def initialize(
196196
for individual modifiers.
197197
"""
198198
self._initialized = True
199-
200-
if loggers:
201-
self.initialize_loggers(loggers)
199+
self.initialize_loggers(loggers)
202200

203201
def initialize_loggers(self, loggers: Union[None, List[BaseLogger]]):
204202
"""

src/sparseml/pytorch/optim/optimizer.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -81,11 +81,11 @@ def __init__(
8181
):
8282
# do not call into super since this instance is not passing all calls to
8383
# the nested optimizer
84-
warnings.warn(
85-
"ScheduledOptimizer is deprecated and will be deleted in the future. "
86-
"Please replace with manager.modify",
87-
UserWarning,
88-
)
84+
# warnings.warn(
85+
# "ScheduledOptimizer is deprecated and will be deleted in the future. "
86+
# "Please replace with manager.modify",
87+
# UserWarning,
88+
# ) TODO: uncomment in next release once docs are ready
8989

9090
manager.initialize(module, epoch=0.0, loggers=loggers)
9191
self._wrapper = RecipeManagerStepWrapper(
@@ -107,7 +107,7 @@ def __getattr__(self, item):
107107
if item in self.__dict__:
108108
return getattr(self, item)
109109

110-
return getattr(self._wrapped, item)
110+
return getattr(self._wrapper.wrapped_optimizer, item)
111111

112112
def __setattr__(self, key, value):
113113
if key in [
@@ -118,23 +118,23 @@ def __setattr__(self, key, value):
118118
]:
119119
super().__setattr__(key, value)
120120
else:
121-
setattr(self._optimizer, key, value)
121+
setattr(self._wrapper.wrapped_optimizer, key, value)
122122

123123
@property
124124
def learning_rate(self) -> float:
125125
"""
126126
:return: convenience function to get the first learning rate for any of
127127
the param groups in the optimizer
128128
"""
129-
return get_optim_learning_rate(self._optimizer)
129+
return get_optim_learning_rate(self._wrapper.wrapped_optimizer)
130130

131131
@learning_rate.setter
132132
def learning_rate(self, value: float):
133133
"""
134134
:param value: the learning rate to set for the optimizer,
135135
will set all param groups in the optim to this value
136136
"""
137-
set_optim_learning_rate(self._optimizer, value)
137+
set_optim_learning_rate(self._wrapper.wrapped_optimizer, value)
138138

139139
@property
140140
def manager(self) -> ScheduledModifierManager:
@@ -144,10 +144,10 @@ def manager(self) -> ScheduledModifierManager:
144144
return self._wrapper.wrapped_manager
145145

146146
def manager_state_dict(self):
147-
return self._manager.state_dict()
147+
return self._wrapper.wrapped_manager.state_dict()
148148

149149
def load_manager_state_dict(self, state_dict):
150-
self._manager.load_state_dict(state_dict)
150+
self._wrapper.wrapped_manager.load_state_dict(state_dict)
151151

152152
def step(self, closure=None):
153153
"""

tests/sparseml/pytorch/optim/test_modifier.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -216,11 +216,6 @@ def test_log_update(
216216
model = model_lambda()
217217
optimizer = optim_lambda(model)
218218

219-
with pytest.raises(RuntimeError):
220-
modifier.log_update(model, optimizer, test_epoch, test_steps_per_epoch)
221-
222-
self.initialize_helper(modifier, model, log_initialize=False)
223-
224219
with pytest.raises(RuntimeError):
225220
modifier.log_update(model, optimizer, test_epoch, test_steps_per_epoch)
226221

@@ -496,11 +491,6 @@ def test_scheduled_log_update(
496491
model = model_lambda()
497492
optimizer = optim_lambda(model)
498493

499-
with pytest.raises(RuntimeError):
500-
modifier.scheduled_log_update(model, optimizer, 0.0, test_steps_per_epoch)
501-
502-
self.initialize_helper(modifier, model, log_initialize=False)
503-
504494
with pytest.raises(RuntimeError):
505495
modifier.scheduled_log_update(model, optimizer, 0.0, test_steps_per_epoch)
506496

Comments (0)

This commit has no comments.