
Commit 7e0dc32

[Transform] [Utils] Canonical matching utilities (#392)

Authored by Kyle Sayers

Squashed commit messages:
* matching utilities
* match_named_parameters
* fix typo
* use match_named_modules
* implement match_modules_set
* proper defaults
* fix typo
* add tests, fix bug

Signed-off-by: Kyle Sayers <[email protected]>

1 parent: 09b7ed4

File tree: 6 files changed, +629 −7 lines

src/compressed_tensors/compressors/model_compressors/model_compressor.py
Lines changed: 2 additions & 2 deletions

@@ -754,8 +754,8 @@ def map_module_to_scheme(model: Module) -> Dict[str, QuantizationScheme]:
         fix_fsdp_module_name(name): module.quantization_scheme
         for name, module in model.named_modules()
         if (
-            hasattr(module, "quantization_scheme") and
-            module.quantization_scheme.weights is not None
+            hasattr(module, "quantization_scheme")
+            and module.quantization_scheme.weights is not None
         )
     }
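The hunk above is behavior-preserving; it only moves the `and` to the start of the continuation line. For illustration, a minimal sketch of the predicate it expresses, using a hypothetical `quantization_scheme` stand-in rather than the real `QuantizationScheme` class:

from types import SimpleNamespace

import torch

# hypothetical stand-ins: one module with weight quantization, one without
quantized = torch.nn.Linear(4, 4)
quantized.quantization_scheme = SimpleNamespace(weights="int8-args")
unquantized = torch.nn.LayerNorm(4)

def has_weight_quant(module: torch.nn.Module) -> bool:
    # same predicate as in map_module_to_scheme above
    return (
        hasattr(module, "quantization_scheme")
        and module.quantization_scheme.weights is not None
    )

assert has_weight_quant(quantized)
assert not has_weight_quant(unquantized)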

src/compressed_tensors/quantization/lifecycle/initialize.py
Lines changed: 6 additions & 1 deletion

@@ -189,7 +189,12 @@ def _initialize_scale_zero_point(
     else:
         # TODO: consider erroring out in the future as if the dtype if not one of these,
         # there is likely bug
-        if scale_dtype not in [torch.float16, torch.bfloat16, torch.float32, torch.float64]:
+        if scale_dtype not in [
+            torch.float16,
+            torch.bfloat16,
+            torch.float32,
+            torch.float64,
+        ]:
             scale_dtype = torch.float16
     zp_dtype = quantization_args.pytorch_dtype()
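Functionally this hunk is a pure reformat of the dtype allowlist check. A standalone sketch of the fallback it encodes (the float8 dtype below is just an illustrative out-of-allowlist input and assumes a torch build with float8 support, >= 2.1):

import torch

def resolve_scale_dtype(scale_dtype: torch.dtype) -> torch.dtype:
    # mirrors the hunk above: dtypes outside the float16/bfloat16/
    # float32/float64 allowlist fall back to float16
    if scale_dtype not in [
        torch.float16,
        torch.bfloat16,
        torch.float32,
        torch.float64,
    ]:
        scale_dtype = torch.float16
    return scale_dtype

assert resolve_scale_dtype(torch.bfloat16) is torch.bfloat16  # kept
assert resolve_scale_dtype(torch.float8_e4m3fn) is torch.float16  # coerced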

src/compressed_tensors/transform/factory/base.py
Lines changed: 3 additions & 4 deletions

@@ -18,7 +18,6 @@
 import torch
 import torch.nn.utils.parametrize as P
 from compressed_tensors import InternalModule
-from compressed_tensors.quantization.lifecycle import is_target  # TODO: move to utils
 from compressed_tensors.registry.registry import RegistryMixin, T
 from compressed_tensors.transform import (
     TransformArgs,
@@ -29,6 +28,7 @@
     align_module_device,
     delete_offload_module,
     has_offloaded_params,
+    match_named_modules,
     patch_attr,
     register_offload_module,
     update_offload_parameter,
@@ -87,9 +87,8 @@ def apply_to_model(self, model: Module):
         :param model: module to apply transforms to
         """
         for arg in self.scheme.apply:
-            for name, module in list(model.named_modules()):
-                if is_target(name, module, arg.targets, arg.ignore):
-                    self._apply_to_module(module, arg)
+            for _, module in match_named_modules(model, arg.targets, arg.ignore):
+                self._apply_to_module(module, arg)
 
     def _apply_to_module(self, module: Module, args: TransformArgs):
         """

src/compressed_tensors/utils/__init__.py
Lines changed: 1 addition & 0 deletions

@@ -15,6 +15,7 @@
 
 from .helpers import *
 from .internal import *
+from .match import *
 from .offload import *
 from .permutations_24 import *
 from .permute import *

src/compressed_tensors/utils/match.py (new file)
Lines changed: 191 additions & 0 deletions

@@ -0,0 +1,191 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import re
from collections.abc import Generator
from typing import Iterable, Tuple

import torch


_LOGGER: logging.Logger = logging.getLogger(__name__)


__all__ = [
    "match_named_modules",
    "match_named_parameters",
    "match_modules_set",
    "is_match",
    "match_name",
    "match_class",
]


def match_named_modules(
    model: torch.nn.Module,
    targets: Iterable[str],
    ignore: Iterable[str] = tuple(),
    warn_on_fail: bool = False,
) -> Generator[Tuple[str, torch.nn.Module]]:
    """
    Yields names and modules which match `targets` but do not match `ignore`.
    Values are returned in order of `model.named_modules()`

    :param model: model containing submodules to match against
    :param targets: target strings, potentially containing "re:" prefixes
    :param ignore: targets to ignore, potentially containing "re:" prefixes
    :param warn_on_fail: if True, warns if any targets do not match any modules in model
    :return: generator of module names and modules
    """
    unmatched_targets = set(targets)
    for name, module in model.named_modules():
        for target in targets:
            if is_match(name, module, target):
                unmatched_targets -= {target}

                if not any(is_match(name, module, ign) for ign in ignore):
                    yield name, module

    if warn_on_fail:
        for target in unmatched_targets:
            _LOGGER.warning(
                f"Could not match `{target}` in instance of {model.__class__.__name__}"
            )


def match_named_parameters(
    model: torch.nn.Module,
    targets: Iterable[str],
    ignore: Iterable[str] = tuple(),
    warn_on_fail: bool = False,
) -> Generator[Tuple[str, torch.nn.Module, torch.nn.Parameter]]:
    """
    Yields parameters which match `targets` but do not match `ignore`.
    Values are returned in order of `model.named_modules()`

    :param model: model containing params to match against
    :param targets: target strings, potentially containing "re:" prefixes
    :param ignore: targets to ignore, potentially containing "re:" prefixes
    :param warn_on_fail: if True, warns if any targets do not match any params in model
    :return: generator of fully-qualified param names, parent modules, and params
    """
    unmatched_targets = set(targets)
    for module_name, module in model.named_modules():
        for param_name, param in module.named_parameters(recurse=False):
            param_fqn = f"{module_name}.{param_name}"
            for target in targets:
                if match_name(param_fqn, target):
                    unmatched_targets -= {target}

                    if not any(match_name(param_fqn, ign) for ign in ignore):
                        yield param_fqn, module, param

    if warn_on_fail:
        for target in unmatched_targets:
            _LOGGER.warning(
                f"Could not match `{target}` in instance of {model.__class__.__name__}"
            )


def match_modules_set(
    model: torch.nn.Module,
    targets: Iterable[str],
    ignore: Iterable[str] = tuple(),
) -> Generator[Iterable[torch.nn.Module]]:
    """
    Yields modules grouped with the same order and size as `targets`.
    Values are returned in order of `model.named_modules()`

    For example, the following targets would yield modules belonging to the following layers:
    ```python3
    match_modules_set(model, ["q_proj", "k_proj", "v_proj"]) == (
        (
            `model.layers.0.self_attn.q_proj`,
            `model.layers.0.self_attn.k_proj`,
            `model.layers.0.self_attn.v_proj`,
        ),
        (
            `model.layers.1.self_attn.q_proj`,
            `model.layers.1.self_attn.k_proj`,
            `model.layers.1.self_attn.v_proj`,
        ),
        ...
        (
            `model.layers.32.self_attn.q_proj`,
            `model.layers.32.self_attn.k_proj`,
            `model.layers.32.self_attn.v_proj`,
        ),
    )
    ```

    This can be used to match layers to their corresponding downstream counterparts.
    For example, matching layer norms to their subsequent linear layers
    ```python3
    for norm, q, k, v in match_modules_set(model, (norm_tgt, q_tgt, k_tgt, v_tgt)):
        fuse_norm_linears(norm, [q, k, v])
    ```

    :param model: model containing modules to match against
    :param targets: target strings, potentially containing "re:" prefixes
    :param ignore: targets to ignore, potentially containing "re:" prefixes
    """
    matches = dict.fromkeys(targets, None)
    for name, module in model.named_modules():
        # match until we get a full set
        for target in targets:
            if is_match(name, module, target) and not any(
                is_match(name, module, ign) for ign in ignore
            ):
                if matches[target] is not None:
                    raise ValueError(f"Matched a {target} twice before completing set")
                matches[target] = module

        # once we have a full set, yield and reset
        if targets and all((matches[target] is not None for target in targets)):
            yield [matches[target] for target in targets]  # ensure correct ordering
            matches = dict.fromkeys(targets, None)

    # check that none are left over
    unmatched_keys = [match for match, value in matches.items() if value is not None]
    if len(unmatched_keys):
        raise ValueError(f"Unable to match targets into set: {unmatched_keys}")


def is_match(name: str, module: torch.nn.Module, target: str) -> bool:
    """
    Returns true if either module name or module parent classes match against target
    """
    return match_name(name, target) or match_class(module, target)


def match_name(name: str, target: str) -> bool:
    """
    Returns true if target string begins with "re:" and
    regex matches or if target string exactly matches name
    """
    if target.startswith("re:"):
        return re.match(target.removeprefix("re:"), name) is not None
    else:
        return target == name


def match_class(module: torch.nn.Module, target: str) -> bool:
    """
    Returns true if any torch parent class names match the target string exactly
    """
    # will never match against a regex pattern since `:` is not allowed in class names
    return any(
        issubclass(cls, torch.nn.Module) and cls.__name__ == target
        for cls in module.__class__.__mro__
    )
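A hedged end-to-end sketch of the three generators on a toy model (the module names below are illustrative, not taken from the commit or its tests):

import torch
from compressed_tensors.utils import (
    match_modules_set,
    match_named_modules,
    match_named_parameters,
)

class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.q_proj = torch.nn.Linear(8, 8)
        self.k_proj = torch.nn.Linear(8, 8)

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.ModuleList([Block(), Block()])

model = Toy()

# name matching with a "re:" regex prefix
print([name for name, _ in match_named_modules(model, ["re:.*q_proj$"])])
# -> ['layers.0.q_proj', 'layers.1.q_proj']

# parameter matching yields (fully-qualified name, parent module, parameter)
for fqn, parent, param in match_named_parameters(model, ["re:.*k_proj.weight$"]):
    print(fqn, tuple(param.shape))  # e.g. layers.0.k_proj.weight (8, 8)

# grouped matching: one (q, k) pair per Block, yielded in target order
for q, k in match_modules_set(model, ["re:.*q_proj$", "re:.*k_proj$"]):
    assert isinstance(q, torch.nn.Linear) and isinstance(k, torch.nn.Linear)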
