Skip to content

Commit 2bba395

Browse files
committed
fix pruning logging calculation
1 parent e088694 commit 2bba395

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

src/lightning/pytorch/callbacks/pruning.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,7 @@ def apply_pruning(self, amount: Union[int, float]) -> None:
349349
def _log_sparsity_stats(
350350
self, prev: list[tuple[int, int]], curr: list[tuple[int, int]], amount: Union[int, float] = 0
351351
) -> None:
352-
total_params = sum(p.numel() for layer, _ in self._parameters_to_prune for p in layer.parameters())
352+
total_params = sum(total for _, total in curr)
353353
prev_total_zeros = sum(zeros for zeros, _ in prev)
354354
curr_total_zeros = sum(zeros for zeros, _ in curr)
355355
log.info(

0 commit comments

Comments
 (0)