
Commit dce81ea

initialize M-FAC grad buffer based on masks, not param values (#287)
* initialize M-FAC grad buffer based on masks, not param values
* update tests
1 parent 44be313 commit dce81ea

3 files changed: 17 additions, 12 deletions

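The core of the change: the M-FAC scorer previously derived the set of unpruned indices from the parameter values (param.view(-1).nonzero(...)), which drops any weight that happens to be exactly zero even though its mask still marks it as unpruned; deriving the indices from the masks keeps those entries in the gradient buffer. A minimal, self-contained sketch of the difference (the tensors below are illustrative, not taken from sparseml):

import torch

# Illustrative tensors: the last entry is pruned by the mask, while the
# first entry is an unpruned weight that just happens to be zero.
param = torch.tensor([0.0, 0.3, -0.7, 0.0])
mask = torch.tensor([1.0, 1.0, 1.0, 0.0])

# previous behavior: unpruned indices taken from the parameter values
idxs_from_values = param.view(-1).nonzero(as_tuple=False).reshape(-1)  # tensor([1, 2])

# new behavior: unpruned indices taken from the mask
idxs_from_mask = mask.view(-1).nonzero(as_tuple=False).reshape(-1)     # tensor([0, 1, 2])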

src/sparseml/pytorch/optim/mask_pruning.py

Lines changed: 1 addition & 1 deletion
@@ -421,7 +421,7 @@ def pre_optim_step_update(self):
         updates scores and buffers that depend on gradients. Should be called
         before Optimizer.step() to grab the latest gradients
         """
-        self._scorer.pre_optim_step_update()
+        self._scorer.pre_optim_step_update(self._param_masks)

     def pruning_end(self, leave_enabled: bool):
         """

src/sparseml/pytorch/optim/mask_pruning_scorer.py

Lines changed: 14 additions & 10 deletions
@@ -58,10 +58,12 @@ def score_parameters(self) -> List[Tensor]:
         """
         raise NotImplementedError()

-    def pre_optim_step_update(self):
+    def pre_optim_step_update(self, masks: List[Tensor]):
         """
         Perform any required logic for tracking Parameter data and gradients before
         an Optimizer step is applied to the model.
+
+        :param masks: latest masks that are applied to these parameters
         """
         pass

@@ -226,9 +228,11 @@ def score_parameters(self) -> List[Tensor]:

         return self._movement_scores

-    def pre_optim_step_update(self):
+    def pre_optim_step_update(self, masks: List[Tensor]):
         """
         Update movement scores based on the current Parameter weights and gradients
+
+        :param masks: latest masks that are applied to these parameters
         """
         self.check_regen_param_vals()
         for idx, param in enumerate(self._params):
@@ -374,17 +378,19 @@ def score_parameters(self) -> List[Tensor]:

         return param_scores

-    def pre_optim_step_update(self):
+    def pre_optim_step_update(self, masks: List[Tensor]):
         """
         Update the gradient buffer based on the current gradients
+
+        :param masks: latest masks that are applied to these parameters
         """

         if any(param.grad is None for param in self._params):
             # only update buffer if all gradients are computed
             return

         if self._grad_buffer is None:
-            self._setup_grad_buffer()
+            self._setup_grad_buffer(masks)

         # get non-pruned grads
         non_pruned_grads = [
@@ -432,7 +438,7 @@ def mask_update(self, masks: List[Tensor], mask_diffs: List[Tensor]):

         self._latest_h_inv_diag = None  # clear h_inv
         self._grads = None  # clear grads
-        self._setup_grad_buffer()  # reset grad buffer
+        self._setup_grad_buffer(masks)  # reset grad buffer
         torch.cuda.empty_cache()  # release GPU memory

     @staticmethod
@@ -509,12 +515,10 @@ def _calc_params_perterb(self, mask_diffs):
         h_inv, diag = self._latest_h_inv_diag
         return h_inv.mul(-1.0 * weights_to_prune / diag)

-    def _setup_grad_buffer(self):
+    def _setup_grad_buffer(self, masks: Tensor):
         total_nonzero = 0
-        for idx, param in enumerate(self._params):
-            self._unpruned_idxs[idx] = (
-                param.view(-1).nonzero(as_tuple=False).reshape(-1)
-            )
+        for idx, mask in enumerate(masks):
+            self._unpruned_idxs[idx] = mask.view(-1).nonzero(as_tuple=False).reshape(-1)
             total_nonzero += self._unpruned_idxs[idx].numel()
         # only track nonzero grads
         num_grads = self._mfac_options.get_num_grads_for_sparsity(
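Putting the new _setup_grad_buffer behavior together: the unpruned indices now come from the masks, their combined count sizes the gradient buffer, and each pre-step update gathers only those entries of the gradients. A self-contained sketch under illustrative shapes and an illustrative window size (this is not the actual sparseml M-FAC implementation):

import torch

# Illustrative parameters and masks (shapes and values are arbitrary).
params = [torch.randn(4, requires_grad=True), torch.randn(3, requires_grad=True)]
masks = [torch.tensor([1.0, 0.0, 1.0, 1.0]), torch.tensor([1.0, 1.0, 0.0])]

# Mask-based index selection, mirroring the new _setup_grad_buffer above.
unpruned_idxs = [m.view(-1).nonzero(as_tuple=False).reshape(-1) for m in masks]
total_nonzero = sum(idxs.numel() for idxs in unpruned_idxs)

num_grads = 8  # illustrative window size; sparseml reads this from its M-FAC options
grad_buffer = torch.zeros(num_grads, total_nonzero)

# On each step, gather only the unpruned gradient entries into one buffer row.
(params[0].sum() + params[1].sum()).backward()
flat_grads = torch.cat(
    [p.grad.view(-1)[idxs] for p, idxs in zip(params, unpruned_idxs)]
)
grad_buffer[0] = flat_grads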

tests/sparseml/pytorch/optim/test_mask_pruning_scorer.py

Lines changed: 2 additions & 1 deletion
@@ -52,7 +52,8 @@ def test_pruning_scorer(score_type, n_updates):

     for i in range(n_updates):
         _fake_params_random_update(params)
-        scorer.pre_optim_step_update()
+        fake_masks = [(param != 0).type(param.dtype) for param in params]
+        scorer.pre_optim_step_update(fake_masks)
     scores = scorer.score_parameters()
     assert len(scores) == len(params)
