@@ -37,8 +37,6 @@ def match_named_modules(
37
37
targets : Iterable [str ] | None ,
38
38
ignore : Iterable [str ] | None = None ,
39
39
warn_on_fail : bool = False ,
40
- warn_on_unmatched_ignores : bool = False ,
41
- yield_matched_targets : bool = False ,
42
40
preprocess_name : Callable [[str ], str ] = lambda x : x ,
43
41
) -> Generator [Tuple [str , torch .nn .Module ] | Tuple [str , torch .nn .Module , List [str ]]]:
44
42
"""
@@ -49,70 +47,36 @@ def match_named_modules(
49
47
:param targets: target strings, potentially containing "re:" prefixes
50
48
:param ignore: targets to ignore, potentially containing "re:" prefixes
51
49
:param warn_on_fail: if True, warns if any targets do not match any modules in model
52
- :param warn_on_unmatched_ignores: if True, warns if any ignores do not match any modules in model
53
- :param yield_matched_targets: if True, yields the matched targets in addition to the module name and module
54
50
:param preprocess_name: a function to preprocess the module name
55
51
:return: generator of module names and modules
56
52
"""
57
53
ignore = ignore or []
58
54
targets = targets or []
59
55
60
56
unmatched_targets = set (targets )
61
- unmatched_ignores = set (ignore )
62
57
63
- # Note: when yield_matched_targets is True, the ordering of the targets is important
64
- # Order targets by type: exact name match, regex name match, class name match
65
- targets = sorted (targets , key = lambda x : ("re:" in x , x ))
66
58
for name , module in model .named_modules ():
67
59
if isinstance (module , InternalModule ):
68
60
continue
69
61
70
62
# preprocess the module name and module
71
63
name = preprocess_name (name )
72
64
73
- ignore_matched = False
74
- for ign in ignore :
75
- if is_match (name , module , ign ):
76
- unmatched_ignores -= {ign }
77
- ignore_matched = True
78
- break
79
- if ignore_matched :
65
+ if any (is_match (name , module , ign ) for ign in ignore ):
80
66
continue
81
67
82
- matched_target_on_name = []
83
- matched_target_on_class = []
84
- # Check for name matches first (exact then regex, enforced by sort above)
85
68
for target in targets :
86
- if _match_name (name , target ):
69
+ if is_match (name , module , target ):
87
70
unmatched_targets -= {target }
88
- matched_target_on_name .append (target )
89
- if not yield_matched_targets :
90
- break
91
- elif _match_class (module , target ):
92
- unmatched_targets -= {target }
93
- matched_target_on_class .append (target )
94
- if not yield_matched_targets :
95
- break
96
-
97
- matched_targets = matched_target_on_name + matched_target_on_class
98
- if matched_targets :
99
- if yield_matched_targets :
100
- yield name , module , matched_targets
101
- else :
102
71
yield name , module
72
+ break
103
73
104
74
if warn_on_fail :
105
75
for target in unmatched_targets :
106
76
_LOGGER .warning (
107
77
f"Could not match `{ target } ` in instance of { model .__class__ .__name__ } "
108
78
)
109
79
110
- if warn_on_unmatched_ignores :
111
- for ign in unmatched_ignores :
112
- _LOGGER .warning (
113
- f"Unmatched ignore targets: { unmatched_ignores } , in instance of { model .__class__ .__name__ } "
114
- )
115
-
116
80
117
81
def match_named_parameters (
118
82
model : torch .nn .Module ,
@@ -151,6 +115,23 @@ def match_named_parameters(
151
115
)
152
116
153
117
118
+ def match_targets (
119
+ name : str , module : torch .nn .Module , targets : Iterable [str ]
120
+ ) -> Generator [str ]:
121
+ """
122
+ Yields the targets that match the given name and module.
123
+ Outputs are ordered by type: exact name match, regex name match, class name match
124
+ """
125
+ targets = sorted (targets , key = lambda x : ("re:" in x , x ))
126
+ for target in targets :
127
+ if _match_name (name , target ):
128
+ yield target
129
+
130
+ for target in targets :
131
+ if _match_class (module , target ):
132
+ yield target
133
+
134
+
154
135
def match_modules_set (
155
136
model : torch .nn .Module ,
156
137
targets : Iterable [str ],
0 commit comments