This repository was archived by the owner on Jun 3, 2025. It is now read-only.
File tree Expand file tree Collapse file tree 1 file changed +10
-4
lines changed
src/sparseml/pytorch/optim Expand file tree Collapse file tree 1 file changed +10
-4
lines changed Original file line number Diff line number Diff line change @@ -332,6 +332,16 @@ def score_parameters(self) -> List[Tensor]:
332332 given by the OBS method. For the approximated Hessian inverse matrix
333333 H^-1, scores will be W^2 / (2 * diag(H^-1))
334334 """
335+
336+ if self._grad_buffer is None or torch.any(
337+     torch.all(self._grad_buffer == 0.0, dim=1)
338+ ):
339+     # raise Exception if grad buffer is not full
340+     raise RuntimeError(
341+         "MFAC pruning step called, but not enough gradient samples have been "
342+         f"collected. Expected {self._mfac_options.num_grads} samples"
343+     )
344+
335345 if self ._is_ddp :
336346 # move all grads to one device
337347 if self ._is_main_proc :
@@ -450,10 +460,6 @@ def get_name() -> str:
450460
451461 def _score_parameters(self) -> List[Tensor]:
452462 # score params using MFAC and the gathered grad buffers
453- if torch.any(torch.all(self._grads == 0.0, dim=1)):
454-     # if not all grads are captured, return magnitudes as scores
455-     return [torch.abs(param.data) for param in self._params]
456-
457463 # gather non-pruned weights
458464 non_pruned_weights = torch.empty(self._grads.size(1)).to(self._grads.device)
459465 weights_idx = 0
You can’t perform that action at this time.
0 commit comments