Skip to content

Commit 4283b1d

Browse files
committed
Update match.py util fn signatures and small fixes
Signed-off-by: Fynn Schmitt-Ulms <[email protected]>
1 parent b1fa4df commit 4283b1d

File tree

1 file changed

+30
-14
lines changed

1 file changed

+30
-14
lines changed

src/compressed_tensors/utils/match.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
__all__ = [
2828
"match_named_modules",
2929
"match_named_parameters",
30+
"match_targets",
3031
"match_modules_set",
3132
"is_match",
3233
]
@@ -46,25 +47,19 @@ def match_named_modules(
4647
:param targets: target strings, potentially containing "re:" prefixes
4748
:param ignore: targets to ignore, potentially containing "re:" prefixes
4849
:param warn_on_fail: if True, warns if any targets do not match any modules in model
49-
:param preprocess_name: a function to preprocess the module name
5050
:return: generator of module names and modules
5151
"""
52-
ignore = ignore or []
5352
targets = targets or []
53+
ignore = ignore or []
5454

5555
unmatched_targets = set(targets)
5656

5757
for name, module in model.named_modules():
58-
if isinstance(module, InternalModule):
59-
continue
60-
61-
if any(is_match(name, module, ign) for ign in ignore):
62-
continue
63-
6458
for target in targets:
6559
if is_match(name, module, target):
6660
unmatched_targets -= {target}
67-
yield name, module
61+
if not any(is_match(name, module, ign) for ign in ignore):
62+
yield name, module
6863
break
6964

7065
if warn_on_fail:
@@ -76,8 +71,8 @@ def match_named_modules(
7671

7772
def match_named_parameters(
7873
model: torch.nn.Module,
79-
targets: Iterable[str],
80-
ignore: Iterable[str] = tuple(),
74+
targets: Iterable[str] | None = None,
75+
ignore: Iterable[str] | None = None,
8176
warn_on_fail: bool = False,
8277
) -> Generator[Tuple[str, torch.nn.Module, torch.nn.Parameter]]:
8378
"""
@@ -90,6 +85,9 @@ def match_named_parameters(
9085
:param warn_on_fail: if True, warns if any targets do not match any params in model
9186
:return: generator of fully-qualified param names, parent modules, and params
9287
"""
88+
targets = targets or []
89+
ignore = ignore or []
90+
9391
unmatched_targets = set(targets)
9492
for module_name, module in model.named_modules():
9593
if isinstance(module, InternalModule):
@@ -112,15 +110,30 @@ def match_named_parameters(
112110

113111

114112
def match_targets(
115-
name: str, module: torch.nn.Module, targets: Iterable[str]
113+
name: str, module: torch.nn.Module, targets: Iterable[str] | None = None
116114
) -> List[str]:
117115
"""
118116
Returns the targets that match the given name and module.
117+
118+
:param name: the name of the module
119+
:param module: the module to match
120+
:param targets: the target strings, potentially containing "re:" prefixes
121+
:return: the targets that match the given name and module
122+
119123
Outputs are ordered by type: exact name match, regex name match, class name match
120124
"""
125+
targets = targets or []
126+
121127
if isinstance(module, InternalModule):
122128
return []
123129

130+
# The order of the output `matches` list matters, the are arranged from most
131+
# specific to least specific, and this order will be used when merging configs.
132+
# The entries are sorted in the following order:
133+
# 1. matches on exact strings
134+
# 2. matches on regex patterns
135+
# 3. matches on module names
136+
124137
targets = sorted(targets, key=lambda x: ("re:" in x, x))
125138
matched_targets = []
126139
for target in targets:
@@ -136,8 +149,8 @@ def match_targets(
136149

137150
def match_modules_set(
138151
model: torch.nn.Module,
139-
targets: Iterable[str],
140-
ignore: Iterable[str] = tuple(),
152+
targets: Iterable[str] | None = None,
153+
ignore: Iterable[str] | None = None,
141154
) -> Generator[Iterable[torch.nn.Module]]:
142155
"""
143156
Yields modules grouped with the same order and size as `targets`.
@@ -175,6 +188,9 @@ def match_modules_set(
175188
:param targets: target strings, potentially containing "re:" prefixes
176189
:param ignore: targets to ignore, potentially containing "re:" prefixes
177190
"""
191+
targets = targets or []
192+
ignore = ignore or []
193+
178194
matches = dict.fromkeys(targets, None)
179195
for name, module in model.named_modules():
180196
# match until we get a full set

0 commit comments

Comments
 (0)