Skip to content

Commit f012ae6

Browse files
committed
get_matching_layer -> match_modules_set, match_modules, get_linear_layers ->match_named_modules
1 parent 3394c8c commit f012ae6

File tree

5 files changed

+10
-11
lines changed

5 files changed

+10
-11
lines changed

src/llmcompressor/modifiers/smoothquant/base.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,7 @@
1414
handle_mapping_resolution_errors,
1515
)
1616
from llmcompressor.utils.fsdp.helpers import get_fsdp_parent
17-
from llmcompressor.utils.pytorch.module import (
18-
get_matching_layer,
19-
)
17+
from compressed_tensors import match_modules_set
2018
from compressed_tensors import match_named_modules
2119
MINIMUM_SMOOTHING_SCALE = 1e-5
2220

@@ -208,7 +206,7 @@ def _resolve_mappings(self, model: Module) -> List[SmoothQuantMapping]:
208206
balance_layers = []
209207
for balance_suffix in to_balance:
210208
# find the submodule that matches the activation layer
211-
_, balance_layer = get_matching_layer(
209+
_, balance_layer =match_modules_set(
212210
balance_suffix, layer_name, model
213211
)
214212
if balance_layer:

src/llmcompressor/pipelines/layer_sequential/pipeline.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@
1212
from llmcompressor.pipelines.cache import IntermediatesCache
1313
from llmcompressor.pipelines.layer_sequential.helpers import (
1414
capture_first_layer_intermediates,
15-
match_modules,
1615
maybe_inject_pos_embeddings,
1716
to_next_layer_kwargs,
1817
)
18+
from compressed_tensors import match_named_modules
1919
from llmcompressor.pipelines.registry import CalibrationPipeline
2020
from llmcompressor.pipelines.sequential.helpers import (
2121
dispatch_for_sequential,
@@ -67,7 +67,7 @@ def __call__(
6767
# find layers
6868
modifiers = session.lifecycle.recipe.modifiers
6969
sequential_targets = get_sequential_targets(modifiers, model, dataset_args)
70-
layers = match_modules(model, sequential_targets)
70+
layers = match_named_modules(model, sequential_targets)
7171

7272
LifecycleCallbacks.calibration_epoch_start()
7373

src/llmcompressor/pipelines/sequential/helpers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from llmcompressor.modifiers.utils.hooks import HooksMixin
2525
from llmcompressor.utils.helpers import calibration_forward_context, patch_attr
2626
from llmcompressor.utils.pytorch.module import get_no_split_params
27+
from compressed_tensors import match_named_modules
2728

2829
from .ast_helpers import autowrap_forwards
2930

@@ -100,7 +101,7 @@ def trace_subgraphs(
100101
:return: a list of Subgraphs in order of execution
101102
"""
102103
# find modules
103-
targets = match_modules(model, sequential_targets)
104+
targets = match_named_modules(model, sequential_targets)
104105
ancestors = get_sequential_ancestors(model, targets)
105106
offloaded = set(m for m in model.modules() if has_offloaded_params(m))
106107

src/llmcompressor/transformers/compression/helpers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
from tqdm import tqdm
99

1010
from llmcompressor.modifiers import Modifier
11-
from llmcompressor.pytorch.utils import get_linear_layers
1211
from llmcompressor.pytorch.utils.helpers import tensor_sparsity
12+
from compressed_tensors import match_named_modules
1313

1414
__ALL__ = [
1515
"tensor_follows_mask_structure",
@@ -76,7 +76,7 @@ def infer_sparsity_structure_from_model(model: torch.nn.Module) -> Optional[str]
7676
# check for the common sparsity structures
7777
structures = {"2:4"}
7878
for sparsity_structure in structures:
79-
linear_modules = get_linear_layers(model)
79+
linear_modules = match_named_modules(model, linear=True)
8080
offloaded_params = get_state_dict_offloaded_model(model)
8181

8282
linear_modules_with_sparsity_structure = [

tests/llmcompressor/transformers/tracing/test_models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@
1414
WhisperForConditionalGeneration,
1515
)
1616

17-
from llmcompressor.pipelines.sequential.helpers import match_modules
1817
from llmcompressor.transformers.tracing.debug import trace
1918
from llmcompressor.utils.pytorch.module import get_no_split_params
19+
from compressed_tensors import match_named_modules
2020

2121

2222
@pytest.mark.skipif(
@@ -148,7 +148,7 @@ def get_target_modules(model, sequential_targets):
148148
if isinstance(sequential_targets, str):
149149
sequential_targets = [sequential_targets]
150150

151-
return match_modules(model, sequential_targets)
151+
return match_named_modules(model, sequential_targets)
152152

153153

154154
def run_subgraphs(model, subgraphs, inputs):

0 commit comments

Comments
 (0)