|
| 1 | +from typing import Iterable, Tuple |
| 2 | +from collections.abc import Generator |
| 3 | + |
| 4 | +import re |
| 5 | +import torch |
| 6 | +import logging |
| 7 | + |
| 8 | +_LOGGER: logging.Logger = logging.getLogger(__name__) |
| 9 | + |
| 10 | + |
| 11 | +__all__ = ["match_named_modules", "is_match"] |
| 12 | + |
| 13 | + |
| 14 | +def match_named_modules( |
| 15 | + model: torch.nn.Module, |
| 16 | + targets: Iterable[str] = tuple(), |
| 17 | + ignore: Iterable[str] = tuple(), |
| 18 | + warn_on_fail: bool = True |
| 19 | +) -> Generator[Tuple[str, torch.nn.Module], None, None]: |
| 20 | + unmatched_targets = set(targets) |
| 21 | + for name, module in model.named_modules(): |
| 22 | + for target in targets: |
| 23 | + if is_match(name, module, target): |
| 24 | + unmatched_targets.remove(target) |
| 25 | + |
| 26 | + if not any(is_match(name, module, ign) for ign in ignore): |
| 27 | + yield name, module |
| 28 | + |
| 29 | + if warn_on_fail: |
| 30 | + for target in unmatched_targets: |
| 31 | + _LOGGER.warning( |
| 32 | + f"Could not match `{target}` in instance of {model.__class__.__name__}" |
| 33 | + ) |
| 34 | + |
| 35 | +def is_match(name: str, module: torch.nn.Module, target: str) -> bool: |
| 36 | + return _match_name(name, target) or _match_class(module, target) |
| 37 | + |
| 38 | + |
| 39 | +def _match_name(name: str, target: str) -> bool: |
| 40 | + if target.startswith("re:"): |
| 41 | + return re.match(target.removeprefix("re:"), name) |
| 42 | + else: |
| 43 | + return target == name |
| 44 | + |
| 45 | + |
| 46 | +def _match_class(module: torch.nn.Module, target: str) -> bool: |
| 47 | + """ |
| 48 | + Will never match against a regex pattern since `:` is not allowed in class names |
| 49 | + |
| 50 | + """ |
| 51 | + return any( |
| 52 | + issubclass(cls, torch.nn.Module) and cls.__name__ == target |
| 53 | + for cls in module.__class__.__mro__ |
| 54 | + ) |
0 commit comments