Commit c5244da (parent: 5478b43)

Update apply_quantization_config to use match_named_modules

Signed-off-by: Fynn Schmitt-Ulms <[email protected]>

2 files changed: +82, -46 lines

src/compressed_tensors/quantization/lifecycle/apply.py

Lines changed: 29 additions & 41 deletions
@@ -40,6 +40,7 @@
     is_kv_cache_quant_scheme,
 )
 from compressed_tensors.utils.helpers import fix_fsdp_module_name, replace_module
+from compressed_tensors.utils.match import match_named_modules
 from compressed_tensors.utils.offload import update_parameter_data
 from compressed_tensors.utils.safetensors_load import get_safetensors_folder
 from safetensors import safe_open
@@ -147,47 +148,35 @@ def apply_quantization_config(
     if run_compressed:
         from compressed_tensors.linear.compressed_linear import CompressedLinear

-    # list of submodules to ignore
-    ignored_submodules = defaultdict(list)
     # mark appropriate layers for quantization by setting their quantization schemes
-    for name, submodule in model.named_modules():
-        # potentially fix module name to remove FSDP wrapper prefix
-        name = fix_fsdp_module_name(name)
-        if matches := find_name_or_class_matches(name, submodule, config.ignore):
-            for match in matches:
-                ignored_submodules[match].append(name)
-            continue  # layer matches ignore list, continue
-
-        targets = find_name_or_class_matches(name, submodule, target_to_scheme)
-
-        if targets:
-            # mark modules to be quantized by adding
-            # quant scheme to the matching layers
-            scheme = _scheme_from_targets(target_to_scheme, targets, name)
-            if run_compressed:
-                format = config.format
-                if format != CompressionFormat.dense.value:
-                    if isinstance(submodule, torch.nn.Linear):
-                        # TODO: expand to more module types
-                        compressed_linear = CompressedLinear.from_linear(
-                            submodule,
-                            quantization_scheme=scheme,
-                            quantization_format=format,
-                        )
-                        replace_module(model, name, compressed_linear)
-
-            # target matched - add layer and scheme to target list
-            submodule.quantization_scheme = scheme
-
-            names_to_scheme[name] = submodule.quantization_scheme
-
-    if config.ignore is not None and ignored_submodules is not None:
-        if set(config.ignore) - set(ignored_submodules):
-            _LOGGER.warning(
-                "Some layers that were to be ignored were "
-                "not found in the model: "
-                f"{set(config.ignore) - set(ignored_submodules)}"
-            )
+    for name, submodule, matched_targets in match_named_modules(
+        model,
+        target_to_scheme,
+        config.ignore or [],
+        warn_on_fail=True,
+        warn_on_unmatched_ignores=True,
+        return_matched_targets=True,
+        preprocess_name=fix_fsdp_module_name,
+    ):
+        # mark modules to be quantized by adding
+        # quant scheme to the matching layers
+        scheme = _scheme_from_targets(target_to_scheme, matched_targets, name)
+        if run_compressed:
+            format = config.format
+            if format != CompressionFormat.dense.value:
+                if isinstance(submodule, torch.nn.Linear):
+                    # TODO: expand to more module types
+                    compressed_linear = CompressedLinear.from_linear(
+                        submodule,
+                        quantization_scheme=scheme,
+                        quantization_format=format,
+                    )
+                    replace_module(model, name, compressed_linear)
+
+        # target matched - add layer and scheme to target list
+        submodule.quantization_scheme = scheme
+
+        names_to_scheme[name] = submodule.quantization_scheme

     # apply current quantization status across all targeted layers
     apply_quantization_status(model, config.quantization_status)
@@ -429,7 +418,6 @@ def _scheme_from_targets(
 def _merge_schemes(
     schemes_to_merge: List[QuantizationScheme], name: str
 ) -> QuantizationScheme:
-
     kv_cache_quantization_scheme = [
         scheme for scheme in schemes_to_merge if is_kv_cache_quant_scheme(scheme)
     ]
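
To see the rewritten loop in isolation, here is a minimal runnable sketch against this commit's match_named_modules; the module names, target strings, and scheme values below are hypothetical placeholders, not taken from the repository:

import torch

from compressed_tensors.utils.helpers import fix_fsdp_module_name
from compressed_tensors.utils.match import match_named_modules

# Toy module tree standing in for a real model
model = torch.nn.ModuleDict(
    {
        "q_proj": torch.nn.Linear(8, 8),
        "k_proj": torch.nn.Linear(8, 8),
        "lm_head": torch.nn.Linear(8, 2),
    }
)

# Iterating a dict yields its keys, so a target-to-scheme mapping can be
# passed directly as `targets`, mirroring apply_quantization_config
target_to_scheme = {"re:.*_proj$": "scheme-a-placeholder", "Linear": "scheme-b-placeholder"}

for name, submodule, matched_targets in match_named_modules(
    model,
    target_to_scheme,
    ["lm_head"],  # plays the role of config.ignore
    warn_on_fail=True,
    warn_on_unmatched_ignores=True,
    return_matched_targets=True,
    preprocess_name=fix_fsdp_module_name,  # strips FSDP wrapper prefixes
):
    # q_proj and k_proj match both the regex and the Linear class target;
    # lm_head is skipped by the ignore list
    print(name, matched_targets)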

src/compressed_tensors/utils/match.py

Lines changed: 53 additions & 5 deletions
@@ -15,7 +15,7 @@
 import logging
 import re
 from collections.abc import Generator
-from typing import Iterable, Tuple
+from typing import Callable, Iterable, Tuple

 import torch

@@ -36,8 +36,11 @@
 def match_named_modules(
     model: torch.nn.Module,
     targets: Iterable[str],
-    ignore: Iterable[str] = tuple(),
+    ignore: Iterable[str] | None = tuple(),
     warn_on_fail: bool = False,
+    warn_on_unmatched_ignores: bool = False,
+    return_matched_targets: bool = False,
+    preprocess_name: Callable[[str], str] = lambda x: x,
 ) -> Generator[Tuple[str, torch.nn.Module]]:
     """
     Yields names and modules which match `targets` but do not match `ignore`.
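
The new keyword arguments are backwards compatible: by default the generator still yields (name, module) pairs, while return_matched_targets=True switches it to 3-tuples that also carry every matched target string. A small sketch (not from the commit) showing both call shapes:

import torch

from compressed_tensors.utils.match import match_named_modules

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())

# Default behavior: (name, module) pairs, as before this commit
for name, module in match_named_modules(model, ["Linear"]):
    print(name, type(module).__name__)  # -> 0 Linear

# New behavior: (name, module, matched_targets) triples
for name, module, matched in match_named_modules(
    model, ["Linear", "re:0"], return_matched_targets=True
):
    print(name, matched)  # -> 0 ['re:0', 'Linear']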
@@ -49,21 +52,66 @@ def match_named_modules(
     :param warn_on_fail: if True, warns if any targets do not match any modules in model
     :return: generator of module names and modules
     """
+    ignore = ignore or []
+
     unmatched_targets = set(targets)
+    unmatched_ignores = set(ignore)
+
+    # Order targets by type: exact name match, regex name match, class name match
+    targets = sorted(targets, key=lambda x: ("re:" in x, x))
     for name, module in model.named_modules():
+        # preprocess the module name and module
+        name = preprocess_name(name)
+
+        ignore_matched = False
+        for ign in ignore:
+            if is_match(name, module, ign):
+                unmatched_ignores -= {ign}
+                ignore_matched = True
+                break
+        if ignore_matched:
+            continue
+
+        matched_targets = []
+        # Check for name matches first (exact then regex)
+        for target in targets:
+            if match_name(name, target):
+                unmatched_targets -= {target}
+                matched_targets.append(target)
+                if not return_matched_targets:
+                    break
+
+        if not return_matched_targets and matched_targets:
+            # Don't need to check other targets, one match is enough
+            yield name, module
+            continue
+
+        # Check for class matches
         for target in targets:
-            if is_match(name, module, target):
+            if match_class(module, target):
                 unmatched_targets -= {target}
+                matched_targets.append(target)
+                if not return_matched_targets:
+                    break

-        if not any(is_match(name, module, ign) for ign in ignore):
-            yield name, module
+        if matched_targets:
+            if return_matched_targets:
+                yield name, module, matched_targets
+            else:
+                yield name, module

     if warn_on_fail:
         for target in unmatched_targets:
             _LOGGER.warning(
                 f"Could not match `{target}` in instance of {model.__class__.__name__}"
             )

+    if warn_on_unmatched_ignores:
+        for ign in unmatched_ignores:
+            _LOGGER.warning(
+                f"Could not match ignore target `{ign}` in instance of {model.__class__.__name__}"
+            )
+

 def match_named_parameters(
     model: torch.nn.Module,
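
The matching primitives used above (match_name, match_class, is_match) are pre-existing helpers in this module. A minimal sketch of their behavior as inferred from the call sites in this diff, assuming they are importable from compressed_tensors.utils.match:

import torch

from compressed_tensors.utils.match import is_match, match_class, match_name

linear = torch.nn.Linear(4, 4)

# Exact-name target: compared against the full dotted module name
assert match_name("model.layers.0.q_proj", "model.layers.0.q_proj")

# Regex target: the "re:" prefix switches match_name to pattern matching
assert match_name("model.layers.0.q_proj", "re:.*q_proj")

# Class target: matched against the module's class, not its name
assert match_class(linear, "Linear")

# is_match combines all three behaviors, which is how the ignore list is applied
assert is_match("model.layers.0.q_proj", linear, "re:.*q_proj")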
