
Commit d22149f

Simplify signature of match_named_modules

Removed `yield_matched_targets` and `warn_on_unmatched_ignores` from
`match_named_modules` and updated the rest of the code accordingly.

Signed-off-by: Fynn Schmitt-Ulms <[email protected]>

1 parent 01e75b2

3 files changed: +28 −56 lines changed

src/compressed_tensors/quantization/lifecycle/apply.py

Lines changed: 3 additions & 4 deletions
```diff
@@ -39,7 +39,7 @@
     is_kv_cache_quant_scheme,
 )
 from compressed_tensors.utils.helpers import fix_fsdp_module_name, replace_module
-from compressed_tensors.utils.match import match_named_modules
+from compressed_tensors.utils.match import is_match, match_named_modules, match_targets
 from compressed_tensors.utils.offload import update_parameter_data
 from compressed_tensors.utils.safetensors_load import get_safetensors_folder
 from safetensors import safe_open
@@ -145,17 +145,16 @@ def apply_quantization_config(
     from compressed_tensors.linear.compressed_linear import CompressedLinear
 
     # mark appropriate layers for quantization by setting their quantization schemes
-    for name, submodule, matched_targets in match_named_modules(
+    for name, submodule in match_named_modules(
         model,
         target_to_scheme,
         config.ignore or [],
         warn_on_fail=True,
-        warn_on_unmatched_ignores=True,
-        yield_matched_targets=True,
         preprocess_name=fix_fsdp_module_name,
     ):
         # mark modules to be quantized by adding
         # quant scheme to the matching layers
+        matched_targets = list(match_targets(name, submodule, target_to_scheme))
         scheme = _scheme_from_targets(target_to_scheme, matched_targets, name)
         if run_compressed:
             format = config.format
```
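
For orientation, a minimal sketch of the new two-step call pattern on a toy module. The `ToyModel` class and the target strings are hypothetical illustrations, not code from the repository; only `match_named_modules` and `match_targets` come from the diff above.

```python
import torch

from compressed_tensors.utils.match import match_named_modules, match_targets


class ToyModel(torch.nn.Module):
    # hypothetical model used only for this sketch
    def __init__(self):
        super().__init__()
        self.q_proj = torch.nn.Linear(8, 8)
        self.lm_head = torch.nn.Linear(8, 8)


model = ToyModel()
targets = ["re:.*_proj", "Linear"]

for name, submodule in match_named_modules(model, targets, ignore=["lm_head"]):
    # matched targets are now computed by the caller instead of being yielded
    matched_targets = list(match_targets(name, submodule, targets))
    print(name, matched_targets)  # e.g. q_proj ['re:.*_proj', 'Linear']
```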

src/compressed_tensors/utils/match.py

Lines changed: 20 additions & 39 deletions
```diff
@@ -37,8 +37,6 @@ def match_named_modules(
     targets: Iterable[str] | None,
     ignore: Iterable[str] | None = None,
     warn_on_fail: bool = False,
-    warn_on_unmatched_ignores: bool = False,
-    yield_matched_targets: bool = False,
     preprocess_name: Callable[[str], str] = lambda x: x,
 ) -> Generator[Tuple[str, torch.nn.Module] | Tuple[str, torch.nn.Module, List[str]]]:
     """
@@ -49,70 +47,36 @@ def match_named_modules(
     :param targets: target strings, potentially containing "re:" prefixes
     :param ignore: targets to ignore, potentially containing "re:" prefixes
     :param warn_on_fail: if True, warns if any targets do not match any modules in model
-    :param warn_on_unmatched_ignores: if True, warns if any ignores do not match any modules in model
-    :param yield_matched_targets: if True, yields the matched targets in addition to the module name and module
     :param preprocess_name: a function to preprocess the module name
     :return: generator of module names and modules
     """
     ignore = ignore or []
     targets = targets or []
 
     unmatched_targets = set(targets)
-    unmatched_ignores = set(ignore)
 
-    # Note: when yield_matched_targets is True, the ordering of the targets is important
-    # Order targets by type: exact name match, regex name match, class name match
-    targets = sorted(targets, key=lambda x: ("re:" in x, x))
     for name, module in model.named_modules():
         if isinstance(module, InternalModule):
             continue
 
         # preprocess the module name and module
         name = preprocess_name(name)
 
-        ignore_matched = False
-        for ign in ignore:
-            if is_match(name, module, ign):
-                unmatched_ignores -= {ign}
-                ignore_matched = True
-                break
-        if ignore_matched:
+        if any(is_match(name, module, ign) for ign in ignore):
             continue
 
-        matched_target_on_name = []
-        matched_target_on_class = []
-        # Check for name matches first (exact then regex, enforced by sort above)
         for target in targets:
-            if _match_name(name, target):
+            if is_match(name, module, target):
                 unmatched_targets -= {target}
-                matched_target_on_name.append(target)
-                if not yield_matched_targets:
-                    break
-            elif _match_class(module, target):
-                unmatched_targets -= {target}
-                matched_target_on_class.append(target)
-                if not yield_matched_targets:
-                    break
-
-        matched_targets = matched_target_on_name + matched_target_on_class
-        if matched_targets:
-            if yield_matched_targets:
-                yield name, module, matched_targets
-            else:
                 yield name, module
+                break
 
     if warn_on_fail:
         for target in unmatched_targets:
             _LOGGER.warning(
                 f"Could not match `{target}` in instance of {model.__class__.__name__}"
             )
 
-    if warn_on_unmatched_ignores:
-        for ign in unmatched_ignores:
-            _LOGGER.warning(
-                f"Unmatched ignore targets: {unmatched_ignores}, in instance of {model.__class__.__name__}"
-            )
-
 
 def match_named_parameters(
     model: torch.nn.Module,
```
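
The refactor folds the separate `_match_name`/`_match_class` branches into a single `is_match` call. A rough sketch of the assumed semantics follows; `is_match_sketch` is a stand-in written for illustration, not the library's actual implementation.

```python
import re

import torch


def is_match_sketch(name: str, module: torch.nn.Module, target: str) -> bool:
    # assumed semantics: a target hits either on the module name
    # (exact, or regex when prefixed with "re:") or on its class name
    if target.startswith("re:"):
        return re.match(target[3:], name) is not None
    return target == name or target == module.__class__.__name__


print(is_match_sketch("model.layers.0.q_proj", torch.nn.Linear(4, 4), "Linear"))  # True
```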
```diff
@@ -151,6 +115,23 @@ def match_named_parameters(
         )
 
 
+def match_targets(
+    name: str, module: torch.nn.Module, targets: Iterable[str]
+) -> Generator[str]:
+    """
+    Yields the targets that match the given name and module.
+    Outputs are ordered by type: exact name match, regex name match, class name match
+    """
+    targets = sorted(targets, key=lambda x: ("re:" in x, x))
+    for target in targets:
+        if _match_name(name, target):
+            yield target
+
+    for target in targets:
+        if _match_class(module, target):
+            yield target
+
+
 def match_modules_set(
     model: torch.nn.Module,
     targets: Iterable[str],
```
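
To see the ordering documented in the new `match_targets` docstring (exact name match first, then regex name match, then class match), a small hypothetical example; the module name and target strings are made up for illustration.

```python
import torch

from compressed_tensors.utils.match import match_targets

linear = torch.nn.Linear(4, 4)
# one class target, one regex name target, one exact name target (all made up)
targets = ["Linear", "re:.*head", "lm_head"]

# name matches are yielded before class matches, with exact names ahead of regexes
print(list(match_targets("lm_head", linear, targets)))
# expected: ['lm_head', 're:.*head', 'Linear']
```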

tests/test_quantization/lifecycle/test_apply.py

Lines changed: 5 additions & 13 deletions
```diff
@@ -258,15 +258,13 @@ def get_sample_tinyllama_quant_config(status: str = "frozen"):
 
 @requires_accelerate()
 @pytest.mark.parametrize(
-    "ignore,should_raise_warning",
+    "ignore",
     [
-        [("lm_head", "re:.*gate"), False],
-        [("lm_head", "re:.*foobarbaz"), True],
+        ("lm_head", "re:.*gate"),
+        ("lm_head", "re:.*foobarbaz"),
     ],
 )
-def test_apply_quantization_status(caplog, ignore, should_raise_warning):
-    import logging
-
+def test_apply_quantization_status(ignore):
     # load a dense, unquantized tiny llama model
     model = get_tinyllama_model()
     quantization_config_dict = {
@@ -290,10 +288,4 @@ def test_apply_quantization_status(caplog, ignore, should_raise_warning):
     config = QuantizationConfig(**quantization_config_dict)
     config.quantization_status = QuantizationStatus.CALIBRATION
 
-    # mismatch in the ignore key of quantization_config_dict
-    with caplog.at_level(logging.WARNING):
-        apply_quantization_config(model, config)
-        if should_raise_warning:
-            assert len(caplog.text) > 0
-        else:
-            assert len(caplog.text) == 0
+    apply_quantization_config(model, config)
```
