@@ -58,10 +58,12 @@ def score_parameters(self) -> List[Tensor]:
5858 """
5959 raise NotImplementedError ()
6060
61- def pre_optim_step_update (self ):
61+ def pre_optim_step_update (self , masks : List [ Tensor ] ):
6262 """
6363 Perform any required logic for tracking Parameter data and gradients before
6464 an Optimizer step is applied to the model.
65+
66+ :param masks: latest masks that are applied to these parameters
6567 """
6668 pass
6769
@@ -226,9 +228,11 @@ def score_parameters(self) -> List[Tensor]:
226228
227229 return self ._movement_scores
228230
229- def pre_optim_step_update (self ):
231+ def pre_optim_step_update (self , masks : List [ Tensor ] ):
230232 """
231233 Update movement scores based on the current Parameter weights and gradients
234+
235+ :param masks: latest masks that are applied to these parameters
232236 """
233237 self .check_regen_param_vals ()
234238 for idx , param in enumerate (self ._params ):
@@ -374,17 +378,19 @@ def score_parameters(self) -> List[Tensor]:
374378
375379 return param_scores
376380
377- def pre_optim_step_update (self ):
381+ def pre_optim_step_update (self , masks : List [ Tensor ] ):
378382 """
379383 Update the gradient buffer based on the current gradients
384+
385+ :param masks: latest masks that are applied to these parameters
380386 """
381387
382388 if any (param .grad is None for param in self ._params ):
383389 # only update buffer if all gradients are computed
384390 return
385391
386392 if self ._grad_buffer is None :
387- self ._setup_grad_buffer ()
393+ self ._setup_grad_buffer (masks )
388394
389395 # get non-pruned grads
390396 non_pruned_grads = [
@@ -432,7 +438,7 @@ def mask_update(self, masks: List[Tensor], mask_diffs: List[Tensor]):
432438
433439 self ._latest_h_inv_diag = None # clear h_inv
434440 self ._grads = None # clear grads
435- self ._setup_grad_buffer () # reset grad buffer
441+ self ._setup_grad_buffer (masks ) # reset grad buffer
436442 torch .cuda .empty_cache () # release GPU memory
437443
438444 @staticmethod
@@ -509,12 +515,10 @@ def _calc_params_perterb(self, mask_diffs):
509515 h_inv , diag = self ._latest_h_inv_diag
510516 return h_inv .mul (- 1.0 * weights_to_prune / diag )
511517
512- def _setup_grad_buffer (self ):
518+ def _setup_grad_buffer (self , masks : Tensor ):
513519 total_nonzero = 0
514- for idx , param in enumerate (self ._params ):
515- self ._unpruned_idxs [idx ] = (
516- param .view (- 1 ).nonzero (as_tuple = False ).reshape (- 1 )
517- )
520+ for idx , mask in enumerate (masks ):
521+ self ._unpruned_idxs [idx ] = mask .view (- 1 ).nonzero (as_tuple = False ).reshape (- 1 )
518522 total_nonzero += self ._unpruned_idxs [idx ].numel ()
519523 # only track nonzero grads
520524 num_grads = self ._mfac_options .get_num_grads_for_sparsity (
0 commit comments