Skip to content

Commit 8a93d26

Browse files
authored
expand is_match (#416)
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 05e00e4 commit 8a93d26

File tree

2 files changed

+30
-16
lines changed

2 files changed

+30
-16
lines changed

src/compressed_tensors/utils/match.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import logging
1616
import re
1717
from collections.abc import Generator
18-
from typing import Iterable, Mapping, Optional, Tuple
18+
from typing import Iterable, List, Mapping, Optional, Tuple
1919

2020
import torch
2121
from compressed_tensors.utils.internal import InternalModule
@@ -57,10 +57,10 @@ def match_named_modules(
5757
unmatched_targets = set(targets)
5858
for name, module in model.named_modules():
5959
for target in targets:
60-
if is_match(name, module, target, fused):
60+
if is_match(name, module, target, fused=fused):
6161
unmatched_targets -= {target}
6262

63-
if not any(is_match(name, module, ign, fused) for ign in ignore):
63+
if not is_match(name, module, ignore, fused=fused):
6464
yield name, module
6565

6666
if warn_on_fail:
@@ -155,9 +155,7 @@ def match_modules_set(
155155
for name, module in model.named_modules():
156156
# match until we get a full set
157157
for target in targets:
158-
if is_match(name, module, target) and not any(
159-
is_match(name, module, ign) for ign in ignore
160-
):
158+
if is_match(name, module, target, ignore):
161159
if matches[target] is not None:
162160
raise ValueError(f"Matched a {target} twice before completing set")
163161
matches[target] = module
@@ -176,7 +174,8 @@ def match_modules_set(
176174
def is_match(
177175
name: str,
178176
module: torch.nn.Module,
179-
target: str,
177+
targets: str | Iterable[str],
178+
ignore: str | Iterable[str] = tuple(),
180179
fused: Optional[FusedMappping] = None,
181180
) -> bool:
182181
"""
@@ -198,8 +197,17 @@ def is_match(
198197
:fused: optional mapping from suffixes of fused modules to the suffixes of their
199198
corresponding shards
200199
"""
200+
targets = [targets] if isinstance(targets, str) else targets
201+
ignore = [ignore] if isinstance(ignore, str) else ignore
202+
201203
return not isinstance(module, InternalModule) and (
202-
_match_name(name, target, fused) or _match_class(module, target)
204+
any(
205+
_match_name(name, target, fused) or _match_class(module, target)
206+
for target in targets
207+
)
208+
and not any(
209+
_match_name(name, ign, fused) or _match_class(module, ign) for ign in ignore
210+
)
203211
)
204212

205213

tests/test_utils/test_match.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -201,14 +201,20 @@ def test_fused_mapping(self):
201201
"gate_up_proj": ["gate_proj", "up_proj"],
202202
}
203203

204-
assert is_match("dummy.qkv_proj", linear, "re:.*q_proj", mapping) == True
205-
assert is_match("dummy.qkv_proj", linear, "re:.*k_proj", mapping) == True
206-
assert is_match("dummy.qkv_proj", linear, "re:.*v_proj", mapping) == True
207-
assert is_match("dummy.qkv_proj", linear, "Linear", mapping) == True
208-
209-
assert is_match("dummy.gate_up_proj", linear, "re:.*gate_proj", mapping) == True
210-
assert is_match("dummy.gate_up_proj", linear, "re:.*up_proj", mapping) == True
211-
assert is_match("dummy.gate_up_proj", linear, "Linear", mapping) == True
204+
assert is_match("dummy.qkv_proj", linear, "re:.*q_proj", fused=mapping) == True
205+
assert is_match("dummy.qkv_proj", linear, "re:.*k_proj", fused=mapping) == True
206+
assert is_match("dummy.qkv_proj", linear, "re:.*v_proj", fused=mapping) == True
207+
assert is_match("dummy.qkv_proj", linear, "Linear", fused=mapping) == True
208+
209+
assert (
210+
is_match("dummy.gate_up_proj", linear, "re:.*gate_proj", fused=mapping)
211+
== True
212+
)
213+
assert (
214+
is_match("dummy.gate_up_proj", linear, "re:.*up_proj", fused=mapping)
215+
== True
216+
)
217+
assert is_match("dummy.gate_up_proj", linear, "Linear", fused=mapping) == True
212218

213219

214220
class TestMatchNamedModules:

0 commit comments

Comments
 (0)