
Commit 06a1e71

get_layers_params refactor

1 parent 63cb6e6 commit 06a1e71

4 files changed: 4 additions & 45 deletions


src/llmcompressor/modifiers/pruning/constant/base.py

Lines changed: 2 additions & 3 deletions
@@ -8,8 +8,7 @@
     LayerParamMasking,
     param_mask_name,
 )
-from llmcompressor.utils.pytorch.module import get_layers_params
-
+from compressed_tensors import match_named_parameters
 __all__ = ["ConstantPruningModifier"]
 
 
@@ -29,7 +28,7 @@ def on_initialize(self, state: State, **kwargs) -> bool:
         if not state.model:
             return False
 
-        self.parameterized_layers_ = get_layers_params(self.targets, state.model)
+        self.parameterized_layers_ = match_named_parameters(self.targets, state.model)
 
         for layer_param_name, parameterized_layer in self.parameterized_layers_.items():
             self.add_mask(

src/llmcompressor/modifiers/pruning/magnitude/base.py

Lines changed: 2 additions & 3 deletions
@@ -16,8 +16,7 @@
     PruningMaskCreatorArgs,
     PruningMaskFactory,
 )
-from llmcompressor.utils.pytorch.module import get_layers_params
-
+from compressed_tensors import match_named_parameters
 __all__ = ["MagnitudePruningModifier"]
 
 
@@ -73,7 +72,7 @@ def on_initialize(self, state: State, **kwargs) -> bool:
             self.mask_structure
         )
 
-        self.parameterized_layers_ = get_layers_params(state.model)
+        self.parameterized_layers_ = match_named_parameters(state.model)
 
         for layer_param_name, parameterized_layer in self.parameterized_layers_.items():
             self.add_mask(

src/llmcompressor/utils/pytorch/module.py

Lines changed: 0 additions & 18 deletions
@@ -58,10 +58,8 @@
     "get_prunable_layers",
     "get_quantizable_layers",
     "qat_active",
-    "get_layers_params",
     "get_matching_layer",
     "get_no_split_params",
-    "get_layer_by_name",
 ]
 
 ALL_TARGET = "__ALL__"
@@ -292,22 +290,6 @@ def qat_active(module: Module) -> bool:
     return False
 
 
-def get_layers_params(
-    targets: Union[str, List[str]], module: Module
-) -> Dict[str, ModelParameterizedLayer]:
-    params = get_params(targets, module)
-    layers = get_layers(targets, module)
-
-    parameterized_layers = {}
-    for name, param in params.items():
-        param_layer = ModelParameterizedLayer(
-            layer_name=name, layer=layers[name], param_name=name, param=param
-        )
-        parameterized_layers[name] = param_layer
-
-    return parameterized_layers
-
-
 def get_matching_layer(
     target: str, name_to_match: str, module: Module
 ) -> Optional[Tuple[str, Module]]:
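
For reference, the removed get_layers_params returned a plain dict keyed by parameter name, pairing each parameter with the layer that owns it. A rough, self-contained sketch of that contract in plain PyTorch follows (collect_layer_params and ParameterizedLayer are hypothetical stand-ins; the real helper resolved its targets through llmcompressor's get_params and get_layers rather than walking named_parameters directly):

    # Hedged sketch of the removed helper's contract, not the original implementation.
    from dataclasses import dataclass
    from typing import Dict

    import torch.nn as nn


    @dataclass
    class ParameterizedLayer:  # stand-in for llmcompressor's ModelParameterizedLayer
        layer_name: str
        layer: nn.Module
        param_name: str
        param: nn.Parameter


    def collect_layer_params(model: nn.Module) -> Dict[str, ParameterizedLayer]:
        # Map every "block.0.weight"-style parameter name to its owning module.
        modules = dict(model.named_modules())  # includes "" for the root module
        result: Dict[str, ParameterizedLayer] = {}
        for name, param in model.named_parameters():
            owner = name.rsplit(".", 1)[0] if "." in name else ""
            result[name] = ParameterizedLayer(owner, modules[owner], name, param)
        return result

Whatever replaces the helper needs to preserve this mapping shape, since both pruning modifiers iterate the result with .items() and feed each entry to add_mask, as the hunks above show.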
Lines changed: 0 additions & 21 deletions
@@ -1,9 +1,6 @@
 import pytest
 import torch.nn as nn
 
-from llmcompressor.utils.pytorch import get_layer_by_name
-
-
 @pytest.fixture
 def example_nested_module() -> str:
     return nn.Sequential(
@@ -14,21 +11,3 @@ def example_nested_module() -> str:
     )
 
 
-@pytest.mark.unit
-def test_get_layer_by_name(example_nested_module):
-    # Test getting the parent of a nested layer
-    layer = get_layer_by_name("0", example_nested_module)
-    assert layer == example_nested_module[0]
-
-    layer = get_layer_by_name("1.1", example_nested_module)
-    assert layer == example_nested_module[1][1]
-
-    layer = get_layer_by_name("2.0", example_nested_module)
-    assert layer == example_nested_module[2][0]
-
-    layer = get_layer_by_name("2.1", example_nested_module)
-    assert layer == example_nested_module[2][1]
-
-    # Test getting the parent of a non-existent layer
-    with pytest.raises(AttributeError):
-        get_layer_by_name("non_existent_layer", example_nested_module)
