Commit d495843

removed get_layer, get_layers -> match_named_modules
1 parent ba045bc commit d495843

6 files changed: +15 -57 lines changed
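
The change is mechanical: every call site that previously resolved submodules with the llm-compressor helpers get_layers / get_layer now calls match_named_modules imported from compressed_tensors, keeping the (targets, module) argument order of the updated call sites. A minimal before/after sketch of the pattern, not part of this commit; it assumes match_named_modules accepts that argument order and returns a {name: module} mapping, as the call sites and the test assertion in this diff imply. The regex target is illustrative, taken from the removed get_layers docstring.

import torch.nn as nn
from compressed_tensors import match_named_modules

# toy block standing in for a real transformer layer
class TinyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(8)
        self.proj = nn.Linear(8, 8)

model = TinyBlock()

# before (helper removed in this commit):
#   from llmcompressor.utils.pytorch.module import get_layers
#   layers = get_layers("re:.*input_layernorm$", model)

# after: the drop-in replacement used throughout this diff
layers = match_named_modules("re:.*input_layernorm$", model)
for name, layer in layers.items():
    print(name, type(layer).__name__)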

src/llmcompressor/modifiers/awq/base.py

Lines changed: 3 additions & 3 deletions
@@ -30,7 +30,7 @@
 from llmcompressor.pipelines.cache import IntermediatesCache
 from llmcompressor.utils.fsdp.helpers import get_fsdp_parent
 from llmcompressor.utils.helpers import calibration_forward_context
-from llmcompressor.utils.pytorch.module import get_layers
+from compressed_tensors import match_named_modules

 __all__ = ["AWQModifier"]

@@ -305,7 +305,7 @@ def _set_resolved_mappings(self, model: Module) -> None:
         """
         resolved_mappings: list[ResolvedMapping] = []
         for mapping_idx, mapping in enumerate(self.mappings):
-            smooth_layers = get_layers(
+            smooth_layers = match_named_modules(
                 mapping.smooth_layer, model, exclude_internal_modules=True
             )
             smooth_names = [

@@ -329,7 +329,7 @@ def _set_resolved_mappings(self, model: Module) -> None:
             balance_layers, balance_names = [], []
             for balance_regex in mapping.balance_layers:
                 # find the submodules that match the activation layer
-                for balance_suffix, balance_layer in get_layers(
+                for balance_suffix, balance_layer in match_named_modules(
                     balance_regex,
                     smooth_parent,
                     exclude_internal_modules=True,

src/llmcompressor/modifiers/distillation/output/base.py

Lines changed: 3 additions & 3 deletions
@@ -11,7 +11,7 @@
 )
 from llmcompressor.utils.fsdp.context import summon_full_params_context
 from llmcompressor.utils.fsdp.helpers import maybe_get_wrapped, set_wrapped_model
-from llmcompressor.utils.pytorch.module import get_layers
+from compressed_tensors import match_named_modules

 __all__ = ["OutputDistillationModifier"]

@@ -61,8 +61,8 @@ def on_initialize(self, state: State, **kwargs) -> bool:
             else:
                 model_target, teacher_target = target, target

-            model_layers = get_layers(model_target, state.model)
-            teacher_layers = get_layers(teacher_target, state.teacher_model)
+            model_layers = match_named_modules(model_target, state.model)
+            teacher_layers = match_named_modules(teacher_target, state.teacher_model)

             if len(model_layers) < 1:
                 raise ValueError(f"no model layers found for target {target}")

src/llmcompressor/modifiers/obcq/sgpt_base.py

Lines changed: 3 additions & 3 deletions
@@ -13,11 +13,11 @@
 from llmcompressor.modifiers.modifier import Modifier
 from llmcompressor.modifiers.utils.hooks import HooksMixin
 from llmcompressor.utils.pytorch.module import (
-    get_layers,
     get_no_split_params,
     get_prunable_layers,
     match_targets,
 )
+from compressed_tensors import match_named_modules


 class SparsityModifierBase(Modifier):

@@ -114,8 +114,8 @@ def on_initialize(self, state: "State", **kwargs) -> bool:

         # infer module and sequential targets
         self.sequential_targets = self._infer_sequential_targets(model)
-        layers = get_layers(self.sequential_targets, model)
-        self._target_layers = get_layers(
+        layers = match_named_modules(self.sequential_targets, model)
+        self._target_layers = match_named_modules(
             self.targets, model
         )  # layers containing targets

src/llmcompressor/modifiers/smoothquant/base.py

Lines changed: 2 additions & 3 deletions
@@ -15,11 +15,10 @@
 )
 from llmcompressor.utils.fsdp.helpers import get_fsdp_parent
 from llmcompressor.utils.pytorch.module import (
-    get_layers,
     get_matching_layer,
     match_targets,
 )
-
+from compressed_tensors import match_named_modules
 MINIMUM_SMOOTHING_SCALE = 1e-5

@@ -204,7 +203,7 @@ def _resolve_mappings(self, model: Module) -> List[SmoothQuantMapping]:
         """
         resolved_mappings = []
         for to_balance, to_smooth in self.mappings:
-            to_smooth_layers = get_layers(to_smooth, model)
+            to_smooth_layers = match_named_modules(to_smooth, model)
             for layer_name, smooth_layer in to_smooth_layers.items():
                 if not match_targets(layer_name, self.ignore)[0]:
                     balance_layers = []

src/llmcompressor/utils/pytorch/module.py

Lines changed: 2 additions & 42 deletions
@@ -19,6 +19,7 @@
     fix_fsdp_module_name,
     summon_full_params_context,
 )
+from compressed_tensors import match_named_modules

 try:
     quant_err = None

@@ -49,8 +50,6 @@
     "match_targets",
     "get_default_params",
     "match_layers_params",
-    "get_layers",
-    "get_layer",
     "get_terminal_layers",
     "get_prunable_layers",
     "get_quantizable_layers",

@@ -155,45 +154,6 @@ def match_layers_params(
     return resolved


-def get_layers(
-    targets: Union[str, List[str]],
-    module: Module,
-    exclude_internal_modules: bool = False,
-) -> Dict[str, Module]:
-    """
-    Get layers (also known as submodules) of module based on targets
-
-    :param targets: names or regexes to search for
-        Can be regex, e.g. "re:.*input_layernorm$" to find all layers
-        in module whose names end in string "input_layernorm"
-    :param module: Parent module in which to search for targets
-    :param exclude_internal_modules: If True, don't include internal
-        modules added by llm-compressor, e.g. Observers and Transforms.
-        Defaults to False to maintain backward compatibility
-
-    :return: dict of {layer name -> module} of all layers in module
-        that match targets
-    """
-    layer_dict = match_layers_params(targets, module)
-    if exclude_internal_modules:
-        layer_dict = {
-            name: layer
-            for name, layer in layer_dict.items()
-            if not isinstance(layer, InternalModule)
-        }
-
-    return layer_dict
-
-
-def get_layer(target: str, module: Module) -> Tuple[str, Module]:
-    layers = get_layers(target, module)
-    if len(layers) != 1:
-        raise ValueError(f"Expected 1 layer for target {target}, found {len(layers)}")
-    name, layer = next(iter(layers.items()))
-
-    return name, layer
-
-
 def get_terminal_layers(module: Module) -> Dict[str, Module]:
     terminal = {}


@@ -271,7 +231,7 @@ def get_matching_layer(
     :return: Tuple containing the layer name and module that fits the target regex and
         best matches name_to_match, or None if no match can be found
     """
-    potential_matches = get_layers(target, module)
+    potential_matches = match_named_modules(target, module)
     largest_substring = 0
     match = None
     for name, module in potential_matches.items():
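
For callers that relied on the removed get_layer() contract of resolving exactly one module, the same check can be rebuilt on top of match_named_modules. A minimal sketch, not part of this commit; get_single_match is a hypothetical helper, and it assumes the (target, module) argument order used at the call sites above.

from typing import Tuple
from torch.nn import Module
from compressed_tensors import match_named_modules

def get_single_match(target: str, module: Module) -> Tuple[str, Module]:
    # hypothetical helper: dict() normalizes the result whether it is a
    # mapping or an iterable of (name, module) pairs
    matches = dict(match_named_modules(target, module))
    if len(matches) != 1:
        raise ValueError(f"Expected 1 layer for target {target}, found {len(matches)}")
    return next(iter(matches.items()))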

tests/llmcompressor/transformers/obcq/test_obcq_owl.py

Lines changed: 2 additions & 3 deletions
@@ -6,8 +6,7 @@
 from llmcompressor.core.session_functions import create_session
 from llmcompressor.datasets import format_calibration_data
 from llmcompressor.modifiers.obcq import SparseGPTModifier
-from llmcompressor.utils.pytorch.module import get_layers
-
+from compressed_tensors import match_named_modules

 @pytest.mark.integration
 def test_infer_owl_layer_sparsity():

@@ -29,7 +28,7 @@ def test_infer_owl_layer_sparsity():
     dataloader = format_calibration_data(dataset)

     sequential_targets = modifier._infer_sequential_targets(model)
-    layers = get_layers(sequential_targets, model)
+    layers = match_named_modules(sequential_targets, model)
     sparsities = modifier._infer_owl_layer_sparsity(model, layers, dataloader)
     assert sparsities.keys() == layers.keys()
