Skip to content

Commit d60e95f

Browse files
committed
fix pruning hasattr
1 parent cd30ce4 commit d60e95f

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

src/lightning/pytorch/callbacks/pruning.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,8 @@ def make_pruning_permanent(self, module: nn.Module) -> None:
277277

278278
@staticmethod
279279
def _copy_param(new: nn.Module, old: nn.Module, name: str) -> None:
280-
dst = getattr(new, name)
280+
# Check if the parameter has been pruned (has _orig suffix)
281+
dst = getattr(new, name + "_orig") if hasattr(new, name + "_orig") else getattr(new, name)
281282
src = getattr(old, name)
282283
if dst is None or src is None or not isinstance(dst, Tensor) or not isinstance(src, Tensor):
283284
return

0 commit comments

Comments
 (0)