Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 62b1938

Browse files
markurtz and bfineran authored
[cherry-pick] Add emulated_step for pytorch wrapper (#238)
* Add PyTorch emulated_step in manager wrapper for differing steps_per_epoch (#236) * Add PyTorch emulated_step in manager wrapper for differing steps_per_epoch * Add PyTorch emulated_step in manager wrapper for differing steps_per_epoch * update transfer learning notebook quant file name (#237) Co-authored-by: Benjamin Fineran <[email protected]>
1 parent 93e0415 commit 62b1938

File tree

2 files changed

+42
-23
lines changed

2 files changed

+42
-23
lines changed

examples/pytorch_sparse_quantized_transfer_learning/pytorch_sparse_quantized_transfer_learning.ipynb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -332,13 +332,13 @@
332332
"from sparseml.pytorch.utils import ModuleExporter\n",
333333
"\n",
334334
"save_dir = \"pytorch_sparse_quantized_transfer_learning\"\n",
335-
"qat_onnx_graph_name = \"resnet50_imagenette_pruned_qat.onnx\"\n",
336-
"quantized_onnx_path = os.path.join(save_dir, \"resnet50_imagenette_pruned_quant.onnx\")\n",
335+
"quant_onnx_graph_name = \"resnet50_imagenette_pruned_quant.onnx\"\n",
336+
"quantized_onnx_path = os.path.join(save_dir, quant_onnx_graph_name)\n",
337337
"\n",
338338
"exporter = ModuleExporter(model, output_dir=save_dir)\n",
339339
"exporter.export_pytorch(name=\"resnet50_imagenette_pruned_qat.pth\")\n",
340340
"exporter.export_onnx(\n",
341-
" torch.randn(1, 3, 224, 224), name=qat_onnx_graph_name, convert_qat=True\n",
341+
" torch.randn(1, 3, 224, 224), name=quant_onnx_graph_name, convert_qat=True\n",
342342
")\n",
343343
"\n",
344344
"print(f\"Sparse-Quantized ONNX model saved to {quantized_onnx_path}\")"

src/sparseml/pytorch/optim/manager.py

Lines changed: 39 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,43 @@ def step(self, *args, **kwargs):
158158
:param kwargs: Any kwargs to pass to the wrapped objects step function.
159159
:return: The return, if any, from the wrapped objects step function
160160
"""
161+
return self._perform_wrapped_step(*args, **kwargs)
162+
163+
def emulated_step(self):
164+
"""
165+
Emulated step function to be called in place of step when the
166+
number of steps_per_epoch vary across epochs.
167+
The emulated function should be called to keep the steps_per_epoch the same.
168+
Does not call into the step function for the wrapped object,
169+
but does call into the manager to increment the steps.
170+
"""
171+
self._perform_wrapped_step(skip_orig_step=True)
172+
173+
def loss_update(self, loss: Tensor) -> Tensor:
174+
"""
175+
Optional call to update modifiers based on the calculated loss.
176+
Not needed unless one or more of the modifier is using the loss
177+
to make a modification or is modifying the loss itself.
178+
179+
:param loss: the calculated loss after running a forward pass and loss_fn
180+
:return: the modified loss tensor
181+
"""
182+
loss = self._wrapped_manager.loss_update(
183+
loss,
184+
self._wrapped_module,
185+
self._wrapped_optimizer,
186+
self._wrapped_epoch,
187+
self._wrapped_steps_per_epoch,
188+
)
189+
190+
return loss
191+
192+
def _perform_wrapped_step(self, *args, **kwargs) -> Any:
193+
skip_orig_step = (
194+
kwargs["skip_orig_step"] if "skip_orig_step" in kwargs else False
195+
)
196+
ret = None
197+
161198
if self._wrapped_manager.enabled:
162199
self._wrapped_manager.update(
163200
self._wrapped_module,
@@ -172,7 +209,8 @@ def step(self, *args, **kwargs):
172209
self._wrapped_steps_per_epoch,
173210
)
174211

175-
ret = self._wrapped.step(*args, **kwargs)
212+
if not skip_orig_step:
213+
ret = self._wrapped.step(*args, **kwargs)
176214

177215
if self._wrapped_manager.enabled:
178216
self._wrapped_manager.optimizer_post_step(
@@ -192,25 +230,6 @@ def step(self, *args, **kwargs):
192230

193231
return ret
194232

195-
def loss_update(self, loss: Tensor) -> Tensor:
196-
"""
197-
Optional call to update modifiers based on the calculated loss.
198-
Not needed unless one or more of the modifier is using the loss
199-
to make a modification or is modifying the loss itself.
200-
201-
:param loss: the calculated loss after running a forward pass and loss_fn
202-
:return: the modified loss tensor
203-
"""
204-
loss = self._wrapped_manager.loss_update(
205-
loss,
206-
self._wrapped_module,
207-
self._wrapped_optimizer,
208-
self._wrapped_epoch,
209-
self._wrapped_steps_per_epoch,
210-
)
211-
212-
return loss
213-
214233

215234
class ScheduledModifierManager(BaseManager, Modifier):
216235
"""

0 commit comments

Comments (0)