Skip to content

Commit 0f872ef

Browse files
committed
merge in canonical matching utils
Signed-off-by: Kyle Sayers <[email protected]>
1 parent e93a705 commit 0f872ef

File tree

1 file changed

+131
-9
lines changed

1 file changed

+131
-9
lines changed

src/compressed_tensors/utils/match.py

Lines changed: 131 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import logging
1616
import re
17+
from collections import OrderedDict
1718
from collections.abc import Generator
1819
from typing import Iterable, Tuple
1920

@@ -23,20 +24,37 @@
2324
_LOGGER: logging.Logger = logging.getLogger(__name__)
2425

2526

26-
__all__ = ["match_named_modules", "is_match"]
27+
__all__ = [
28+
"match_named_modules",
29+
"match_named_parameters",
30+
"match_modules_set",
31+
"is_match",
32+
"match_name",
33+
"match_class",
34+
]
2735

2836

2937
def match_named_modules(
3038
model: torch.nn.Module,
3139
targets: Iterable[str] = tuple(),
3240
ignore: Iterable[str] = tuple(),
33-
warn_on_fail: bool = True,
34-
) -> Generator[Tuple[str, torch.nn.Module], None, None]:
41+
warn_on_fail: bool = False,
42+
) -> Generator[Tuple[str, torch.nn.Module]]:
43+
"""
44+
Yields names and modules which match `targets` but do not match `ignore`.
45+
Values are returned in order of `model.named_modules()`
46+
47+
:param model: model containing submodules to match against
48+
:param targets: target strings, potentially containing "re:" prefixes
49+
:param ignore: targets to ignore, potentially containing "re:" prefixes
50+
:param warn_on_fail: if True, warns if any targets do not match any modules in model
51+
:return: generator of module names and modules
52+
"""
3553
unmatched_targets = set(targets)
3654
for name, module in model.named_modules():
3755
for target in targets:
3856
if is_match(name, module, target):
39-
unmatched_targets.remove(target)
57+
unmatched_targets -= {target}
4058

4159
if not any(is_match(name, module, ign) for ign in ignore):
4260
yield name, module
@@ -48,22 +66,126 @@ def match_named_modules(
4866
)
4967

5068

69+
def match_named_parameters(
    model: torch.nn.Module,
    targets: Iterable[str],
    ignore: Iterable[str],
    warn_on_fail: bool = False,
) -> Generator[Tuple[str, torch.nn.Module, torch.nn.Parameter], None, None]:
    """
    Yields parameters which match `targets` but do not match `ignore`.
    Values are returned in order of `model.named_modules()`

    Each matching parameter is yielded at most once, even if its name matches
    more than one target. Parameters owned directly by `model` (whose module
    name is the empty string) are reported by their bare name, e.g. `weight`
    rather than `.weight`.

    :param model: model containing params to match against
    :param targets: target strings, potentially containing "re:" prefixes
    :param ignore: targets to ignore, potentially containing "re:" prefixes
    :param warn_on_fail: if True, warns if any targets do not match any params in model.
        The warnings are only emitted once the generator has been exhausted
    :return: generator of fully-qualified param names, parent modules, and params
    """
    # materialize once: both collections are iterated for every parameter,
    # which would silently exhaust a generator or other one-shot iterable
    targets = tuple(targets)
    ignore = tuple(ignore)

    unmatched_targets = set(targets)
    for module_name, module in model.named_modules():
        for param_name, param in module.named_parameters(recurse=False):
            # the root module is named "", so only join with "." when the
            # parameter actually has a named parent
            param_fqn = f"{module_name}.{param_name}" if module_name else param_name

            # evaluate every target so that `unmatched_targets` stays accurate
            matched = [target for target in targets if match_name(param_fqn, target)]
            if not matched:
                continue
            unmatched_targets.difference_update(matched)

            if not any(match_name(param_fqn, ign) for ign in ignore):
                yield param_fqn, module, param

    if warn_on_fail:
        for target in unmatched_targets:
            _LOGGER.warning(
                f"Could not match `{target}` in instance of {model.__class__.__name__}"
            )
101+
102+
103+
def match_modules_set(
    model: torch.nn.Module,
    targets: Iterable[str],
    ignore: Iterable[str] = tuple(),
) -> Generator[Iterable[torch.nn.Module], None, None]:
    """
    Yields modules grouped with the same order and size as `targets`.
    Values are returned in order of `model.named_modules()`

    For example, the following targets would yield modules belonging to the
    following layers:
    ```python3
    match_modules_set(model, ["q_proj", "k_proj", "v_proj"]) == (
        (
            `model.layers.0.self_attn.q_proj`,
            `model.layers.0.self_attn.k_proj`,
            `model.layers.0.self_attn.v_proj`,
        ),
        (
            `model.layers.1.self_attn.q_proj`,
            `model.layers.1.self_attn.k_proj`,
            `model.layers.1.self_attn.v_proj`,
        ),
        ...
        (
            `model.layers.32.self_attn.q_proj`,
            `model.layers.32.self_attn.k_proj`,
            `model.layers.32.self_attn.v_proj`,
        ),
    )
    ```

    This can be used to match layers to their corresponding downstream counterparts.
    For example, matching layer norms to their subsequent linear layers
    ```python3
    for norm, q, k, v in match_modules_set(model, (norm_tgt, q_tgt, k_tgt, v_tgt)):
        fuse_norm_linears(norm, [q, k, v])
    ```

    :param model: model containing modules to match against
    :param targets: target strings, potentially containing "re:" prefixes
    :param ignore: targets to ignore, potentially containing "re:" prefixes.
        Defaults to ignoring nothing
    :return: generator of lists of modules, one module per target, ordered like `targets`.
        Nothing is yielded if `targets` is empty
    :raises ValueError: if a target is matched twice before a set is completed,
        or if iteration ends while a set is only partially filled
    """
    # materialize once: both collections are iterated for every module,
    # which would silently exhaust a generator or other one-shot iterable
    targets = tuple(targets)
    ignore = tuple(ignore)

    # without targets every (empty) set would count as complete and an empty
    # list would be yielded for each module in the model
    if not targets:
        return

    matches = dict.fromkeys(targets, None)
    for name, module in model.named_modules():
        # match until we get a full set
        for target in targets:
            if is_match(name, module, target) and not any(
                is_match(name, module, ign) for ign in ignore
            ):
                if matches[target] is not None:
                    raise ValueError(f"Matched a {target} twice before completing set")
                matches[target] = module

        # once we have a full set, yield and reset
        if all(matches[target] is not None for target in targets):
            yield [matches[target] for target in targets]  # ensure correct ordering
            matches = dict.fromkeys(targets, None)

    # check that no partially filled set is left over. `matches` is reset to
    # all-None after every completed set, so only complain when at least one
    # target was matched since the last reset
    if any(value is not None for value in matches.values()):
        unmatched_keys = [match for match, value in matches.items() if value is None]
        raise ValueError(f"Unable to match targets into set: {unmatched_keys}")
164+
165+
51166
def is_match(name: str, module: torch.nn.Module, target: str) -> bool:
    """
    Returns true if either module name or module parent classes match against target

    :param name: name of the module, as given by `model.named_modules()`
    :param module: module whose class hierarchy is checked against `target`
    :param target: target string, potentially containing a "re:" prefix
    :return: True if the name or any parent class name matches, False otherwise
    """
    # coerce the name result: a regex hit may come back as a match object
    # rather than a bool, and this function is documented to return a bool
    return bool(match_name(name, target)) or match_class(module, target)
53171

54172

55-
def _match_name(name: str, target: str) -> bool:
173+
def match_name(name: str, target: str) -> bool:
    """
    Returns true if target string begins with "re:" and
    regex matches or if target string exactly matches name

    Regex targets are applied with `re.match`, meaning the pattern is anchored
    at the start of `name` but does not need to consume the entire string.

    :param name: name to test, e.g. a module or parameter name
    :param target: exact name, or a regex pattern prefixed with "re:"
    :return: True if `name` matches `target`, False otherwise
    """
    if target.startswith("re:"):
        # `re.match` returns a match object or None, convert to a real bool
        return re.match(target.removeprefix("re:"), name) is not None
    else:
        return target == name
60182

61183

62-
def _match_class(module: torch.nn.Module, target: str) -> bool:
184+
def match_class(module: torch.nn.Module, target: str) -> bool:
63185
"""
64-
Will never match against a regex pattern since `:` is not allowed in class names
65-
186+
Returns true if any torch parent class names match the target string exactly
66187
"""
188+
# will never match against a regex pattern since `:` is not allowed in class names
67189
return any(
68190
issubclass(cls, torch.nn.Module) and cls.__name__ == target
69191
for cls in module.__class__.__mro__

0 commit comments

Comments
 (0)