Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 97bf486

Browse files
authored
PyTorch modifier fixes (#239) (#240)
- allow block_shape to be more than two dimensions since on export it changes to weight shape - move initialized check under try catch in delete for safety
1 parent 62b1938 commit 97bf486

File tree

2 files changed

+15
-6
lines changed

2 files changed

+15
-6
lines changed

src/sparseml/pytorch/optim/mask_creator_pruning.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -474,13 +474,23 @@ def __init__(
474474
block_shape: List[int],
475475
grouping_fn_name: str = "mean",
476476
):
477-
if len(block_shape) != 2:
477+
if len(block_shape) < 2:
478478
raise ValueError(
479479
(
480-
"Invalid block_shape: {}"
481-
" ,block_shape must have length == 2 for in and out channels"
480+
"Invalid block_shape: {}, "
481+
"block_shape must have length == 2 for in and out channels"
482482
).format(block_shape)
483483
)
484+
485+
if len(block_shape) > 2 and not all([shape == 1 for shape in block_shape[2:]]):
486+
# after in and out channels, only 1 can be used for other dimensions
487+
raise ValueError(
488+
(
489+
"Invalid block_shape: {}, "
490+
"block_shape for indices not in [0, 1] must be equal to 1"
491+
).format(block_shape)
492+
)
493+
484494
self._block_shape = block_shape
485495
self._grouping_fn_name = grouping_fn_name
486496

src/sparseml/pytorch/optim/modifier.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,10 +104,9 @@ def __init__(self, log_types: Union[str, List[str]] = None, **kwargs):
104104
self._loggers = None
105105

106106
def __del__(self):
107-
if not self.initialized:
108-
return
109-
110107
try:
108+
if not self.initialized:
109+
return
111110
self.finalize()
112111
except Exception:
113112
pass

0 commit comments

Comments
 (0)