Commit ad74d32

Update apply_quantization_config to use match_named_modules
Signed-off-by: Fynn Schmitt-Ulms <[email protected]>
1 parent 3d49764 · commit ad74d32

File tree: 2 files changed, +82 / -46 lines changed

src/compressed_tensors/quantization/lifecycle/apply.py

Lines changed: 29 additions & 41 deletions
@@ -40,6 +40,7 @@
     is_kv_cache_quant_scheme,
 )
 from compressed_tensors.utils.helpers import fix_fsdp_module_name, replace_module
+from compressed_tensors.utils.match import match_named_modules
 from compressed_tensors.utils.offload import update_parameter_data
 from compressed_tensors.utils.safetensors_load import get_safetensors_folder
 from safetensors import safe_open
@@ -147,47 +148,35 @@ def apply_quantization_config(
     if run_compressed:
         from compressed_tensors.linear.compressed_linear import CompressedLinear
 
-    # list of submodules to ignore
-    ignored_submodules = defaultdict(list)
     # mark appropriate layers for quantization by setting their quantization schemes
-    for name, submodule in model.named_modules():
-        # potentially fix module name to remove FSDP wrapper prefix
-        name = fix_fsdp_module_name(name)
-        if matches := find_name_or_class_matches(name, submodule, config.ignore):
-            for match in matches:
-                ignored_submodules[match].append(name)
-            continue  # layer matches ignore list, continue
-
-        targets = find_name_or_class_matches(name, submodule, target_to_scheme)
-
-        if targets:
-            # mark modules to be quantized by adding
-            # quant scheme to the matching layers
-            scheme = _scheme_from_targets(target_to_scheme, targets, name)
-            if run_compressed:
-                format = config.format
-                if format != CompressionFormat.dense.value:
-                    if isinstance(submodule, torch.nn.Linear):
-                        # TODO: expand to more module types
-                        compressed_linear = CompressedLinear.from_linear(
-                            submodule,
-                            quantization_scheme=scheme,
-                            quantization_format=format,
-                        )
-                        replace_module(model, name, compressed_linear)
-
-            # target matched - add layer and scheme to target list
-            submodule.quantization_scheme = scheme
-
-            names_to_scheme[name] = submodule.quantization_scheme
-
-    if config.ignore is not None and ignored_submodules is not None:
-        if set(config.ignore) - set(ignored_submodules):
-            _LOGGER.warning(
-                "Some layers that were to be ignored were "
-                "not found in the model: "
-                f"{set(config.ignore) - set(ignored_submodules)}"
-            )
+    for name, submodule, matched_targets in match_named_modules(
+        model,
+        target_to_scheme,
+        config.ignore or [],
+        warn_on_fail=True,
+        warn_on_unmatched_ignores=True,
+        return_matched_targets=True,
+        preprocess_name=fix_fsdp_module_name,
+    ):
+        # mark modules to be quantized by adding
+        # quant scheme to the matching layers
+        scheme = _scheme_from_targets(target_to_scheme, matched_targets, name)
+        if run_compressed:
+            format = config.format
+            if format != CompressionFormat.dense.value:
+                if isinstance(submodule, torch.nn.Linear):
+                    # TODO: expand to more module types
+                    compressed_linear = CompressedLinear.from_linear(
+                        submodule,
+                        quantization_scheme=scheme,
+                        quantization_format=format,
+                    )
+                    replace_module(model, name, compressed_linear)
+
+        # target matched - add layer and scheme to target list
+        submodule.quantization_scheme = scheme
+
+        names_to_scheme[name] = submodule.quantization_scheme
 
     # apply current quantization status across all targeted layers
     apply_quantization_status(model, config.quantization_status)
@@ -429,7 +418,6 @@ def _scheme_from_targets(
 def _merge_schemes(
     schemes_to_merge: List[QuantizationScheme], name: str
 ) -> QuantizationScheme:
-
     kv_cache_quantization_scheme = [
         scheme for scheme in schemes_to_merge if is_kv_cache_quant_scheme(scheme)
     ]
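
For context, a minimal sketch (not part of the commit) of the iteration pattern the refactored apply_quantization_config adopts; the toy model and the "W8A8" placeholder scheme below are hypothetical:

import torch
from compressed_tensors.utils.match import match_named_modules

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
target_to_scheme = {"Linear": "W8A8"}  # hypothetical scheme mapping

for name, submodule, matched_targets in match_named_modules(
    model,
    target_to_scheme,
    ignore=[],
    return_matched_targets=True,
):
    # expected to yield the Linear submodule at name "0", matched via its
    # class name: ("0", Linear(...), ["Linear"])
    print(name, type(submodule).__name__, matched_targets)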

src/compressed_tensors/utils/match.py

Lines changed: 53 additions & 5 deletions
@@ -15,7 +15,7 @@
 import logging
 import re
 from collections.abc import Generator
-from typing import Iterable, Tuple
+from typing import Callable, Iterable, Tuple
 
 import torch
 from compressed_tensors.utils.internal import InternalModule
@@ -35,8 +35,11 @@
 def match_named_modules(
     model: torch.nn.Module,
     targets: Iterable[str],
-    ignore: Iterable[str] = tuple(),
+    ignore: Iterable[str] | None = tuple(),
     warn_on_fail: bool = False,
+    warn_on_unmatched_ignores: bool = False,
+    return_matched_targets: bool = False,
+    preprocess_name: Callable[[str], str] = lambda x: x,
 ) -> Generator[Tuple[str, torch.nn.Module]]:
     """
     Yields names and modules which match `targets` but do not match `ignore`.
@@ -48,21 +51,66 @@ def match_named_modules(
     :param warn_on_fail: if True, warns if any targets do not match any modules in model
     :return: generator of module names and modules
     """
+    ignore = ignore or []
+
     unmatched_targets = set(targets)
+    unmatched_ignores = set(ignore)
+
+    # Order targets by type: exact name match, regex name match, class name match
+    targets = sorted(targets, key=lambda x: ("re:" in x, x))
     for name, module in model.named_modules():
+        # preprocess the module name and module
+        name = preprocess_name(name)
+
+        ignore_matched = False
+        for ign in ignore:
+            if is_match(name, module, ign):
+                unmatched_ignores -= {ign}
+                ignore_matched = True
+                break
+        if ignore_matched:
+            continue
+
+        matched_targets = []
+        # Check for name matches first (exact then regex)
         for target in targets:
-            if is_match(name, module, target):
+            if match_name(name, target):
                 unmatched_targets -= {target}
+                matched_targets.append(target)
+                if not return_matched_targets:
+                    break
 
-        if not any(is_match(name, module, ign) for ign in ignore):
-            yield name, module
+        if not return_matched_targets and matched_targets:
+            # Don't need to check other targets, one match is enough
+            yield name, module
+            continue
+
+        # Check for class matches
+        for target in targets:
+            if match_class(module, target):
+                unmatched_targets -= {target}
+                matched_targets.append(target)
+                if not return_matched_targets:
+                    break
+
+        if matched_targets:
+            if return_matched_targets:
+                yield name, module, matched_targets
+            else:
+                yield name, module
 
     if warn_on_fail:
         for target in unmatched_targets:
             _LOGGER.warning(
                 f"Could not match `{target}` in instance of {model.__class__.__name__}"
             )
 
+    if warn_on_unmatched_ignores:
+        for ign in unmatched_ignores:
+            _LOGGER.warning(
+                f"Unmatched ignore target `{ign}` in instance of {model.__class__.__name__}"
+            )
+
 
 def match_named_parameters(
     model: torch.nn.Module,
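
Assuming the "re:" prefix denotes a regex target (as the sorting comment in the diff suggests), here is a hypothetical usage sketch of the new options; the module names and ignore entries below are invented for illustration:

import torch
from compressed_tensors.utils.match import match_named_modules

model = torch.nn.ModuleDict(
    {"q_proj": torch.nn.Linear(4, 4), "lm_head": torch.nn.Linear(4, 4)}
)

for name, module in match_named_modules(
    model,
    targets=["re:.*_proj"],
    ignore=["lm_head", "does_not_exist"],
    warn_on_unmatched_ignores=True,  # logs that "does_not_exist" never matched
):
    print(name)  # expected: "q_proj"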
