15
15
import logging
16
16
import re
17
17
from collections .abc import Generator
18
- from typing import Callable , Iterable , Tuple
18
+ from typing import Callable , Iterable , List , Tuple
19
19
20
20
import torch
21
21
from compressed_tensors .utils .internal import InternalModule
@@ -38,9 +38,9 @@ def match_named_modules(
38
38
ignore : Iterable [str ] | None = None ,
39
39
warn_on_fail : bool = False ,
40
40
warn_on_unmatched_ignores : bool = False ,
41
- return_matched_targets : bool = False ,
41
+ yield_matched_targets : bool = False ,
42
42
preprocess_name : Callable [[str ], str ] = lambda x : x ,
43
- ) -> Generator [Tuple [str , torch .nn .Module ]]:
43
+ ) -> Generator [Tuple [str , torch .nn .Module ] | Tuple [ str , torch . nn . Module , List [ str ]] ]:
44
44
"""
45
45
Yields names and modules which match `targets` but do not match `ignore`.
46
46
Values are returned in order of `model.named_modules()`
@@ -49,6 +49,9 @@ def match_named_modules(
49
49
:param targets: target strings, potentially containing "re:" prefixes
50
50
:param ignore: targets to ignore, potentially containing "re:" prefixes
51
51
:param warn_on_fail: if True, warns if any targets do not match any modules in model
52
+ :param warn_on_unmatched_ignores: if True, warns if any ignores do not match any modules in model
53
+ :param yield_matched_targets: if True, yields the matched targets in addition to the module name and module
54
+ :param preprocess_name: a function to preprocess the module name
52
55
:return: generator of module names and modules
53
56
"""
54
57
ignore = ignore or []
@@ -57,6 +60,7 @@ def match_named_modules(
57
60
unmatched_targets = set (targets )
58
61
unmatched_ignores = set (ignore )
59
62
63
+ # Note: when yield_matched_targets is True, the ordering of the targets is important
60
64
# Order targets by type: exact name match, regex name match, class name match
61
65
targets = sorted (targets , key = lambda x : ("re:" in x , x ))
62
66
for name , module in model .named_modules ():
@@ -75,30 +79,24 @@ def match_named_modules(
75
79
if ignore_matched :
76
80
continue
77
81
78
- matched_targets = []
79
- # Check for name matches first (exact then regex)
82
+ matched_target_on_name = []
83
+ matched_target_on_class = []
84
+ # Check for name matches first (exact then regex, enforced by sort above)
80
85
for target in targets :
81
86
if _match_name (name , target ):
82
87
unmatched_targets -= {target }
83
- matched_targets .append (target )
84
- if not return_matched_targets :
88
+ matched_target_on_name .append (target )
89
+ if not yield_matched_targets :
85
90
break
86
-
87
- if not return_matched_targets and matched_targets :
88
- # Don't need to check other targets, one match is enough
89
- yield name , module
90
- continue
91
-
92
- # Check for class matches
93
- for target in targets :
94
- if _match_class (module , target ):
91
+ elif _match_class (module , target ):
95
92
unmatched_targets -= {target }
96
- matched_targets .append (target )
97
- if not return_matched_targets :
93
+ matched_target_on_class .append (target )
94
+ if not yield_matched_targets :
98
95
break
99
96
97
+ matched_targets = matched_target_on_name + matched_target_on_class
100
98
if matched_targets :
101
- if return_matched_targets :
99
+ if yield_matched_targets :
102
100
yield name , module , matched_targets
103
101
else :
104
102
yield name , module
0 commit comments