Skip to content

Commit bd4a991

Browse files
committed
use matching utils
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 073d8d7 commit bd4a991

File tree

2 files changed

+55
-0
lines changed

2 files changed

+55
-0
lines changed

src/compressed_tensors/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,4 @@
2020
from .permute import *
2121
from .safetensors_load import *
2222
from .semi_structured_conversions import *
23+
from .match import *

src/compressed_tensors/utils/match.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
from typing import Iterable, Tuple
2+
from collections.abc import Generator
3+
4+
import re
5+
import torch
6+
import logging
7+
8+
_LOGGER: logging.Logger = logging.getLogger(__name__)
9+
10+
11+
__all__ = ["match_named_modules", "is_match"]
12+
13+
14+
def match_named_modules(
15+
model: torch.nn.Module,
16+
targets: Iterable[str] = tuple(),
17+
ignore: Iterable[str] = tuple(),
18+
warn_on_fail: bool = True
19+
) -> Generator[Tuple[str, torch.nn.Module], None, None]:
20+
unmatched_targets = set(targets)
21+
for name, module in model.named_modules():
22+
for target in targets:
23+
if is_match(name, module, target):
24+
unmatched_targets.remove(target)
25+
26+
if not any(is_match(name, module, ign) for ign in ignore):
27+
yield name, module
28+
29+
if warn_on_fail:
30+
for target in unmatched_targets:
31+
_LOGGER.warning(
32+
f"Could not match `{target}` in instance of {model.__class__.__name__}"
33+
)
34+
35+
def is_match(name: str, module: torch.nn.Module, target: str) -> bool:
36+
return _match_name(name, target) or _match_class(module, target)
37+
38+
39+
def _match_name(name: str, target: str) -> bool:
40+
if target.startswith("re:"):
41+
return re.match(target.removeprefix("re:"), name)
42+
else:
43+
return target == name
44+
45+
46+
def _match_class(module: torch.nn.Module, target: str) -> bool:
47+
"""
48+
Will never match against a regex pattern since `:` is not allowed in class names
49+
50+
"""
51+
return any(
52+
issubclass(cls, torch.nn.Module) and cls.__name__ == target
53+
for cls in module.__class__.__mro__
54+
)

0 commit comments

Comments
 (0)