27
27
__all__ = [
28
28
"match_named_modules" ,
29
29
"match_named_parameters" ,
30
+ "match_targets" ,
30
31
"match_modules_set" ,
31
32
"is_match" ,
32
33
]
@@ -46,25 +47,19 @@ def match_named_modules(
46
47
:param targets: target strings, potentially containing "re:" prefixes
47
48
:param ignore: targets to ignore, potentially containing "re:" prefixes
48
49
:param warn_on_fail: if True, warns if any targets do not match any modules in model
49
- :param preprocess_name: a function to preprocess the module name
50
50
:return: generator of module names and modules
51
51
"""
52
- ignore = ignore or []
53
52
targets = targets or []
53
+ ignore = ignore or []
54
54
55
55
unmatched_targets = set (targets )
56
56
57
57
for name , module in model .named_modules ():
58
- if isinstance (module , InternalModule ):
59
- continue
60
-
61
- if any (is_match (name , module , ign ) for ign in ignore ):
62
- continue
63
-
64
58
for target in targets :
65
59
if is_match (name , module , target ):
66
60
unmatched_targets -= {target }
67
- yield name , module
61
+ if not any (is_match (name , module , ign ) for ign in ignore ):
62
+ yield name , module
68
63
break
69
64
70
65
if warn_on_fail :
@@ -76,8 +71,8 @@ def match_named_modules(
76
71
77
72
def match_named_parameters (
78
73
model : torch .nn .Module ,
79
- targets : Iterable [str ],
80
- ignore : Iterable [str ] = tuple () ,
74
+ targets : Iterable [str ] | None = None ,
75
+ ignore : Iterable [str ] | None = None ,
81
76
warn_on_fail : bool = False ,
82
77
) -> Generator [Tuple [str , torch .nn .Module , torch .nn .Parameter ]]:
83
78
"""
@@ -90,6 +85,9 @@ def match_named_parameters(
90
85
:param warn_on_fail: if True, warns if any targets do not match any params in model
91
86
:return: generator of fully-qualified param names, parent modules, and params
92
87
"""
88
+ targets = targets or []
89
+ ignore = ignore or []
90
+
93
91
unmatched_targets = set (targets )
94
92
for module_name , module in model .named_modules ():
95
93
if isinstance (module , InternalModule ):
@@ -112,15 +110,30 @@ def match_named_parameters(
112
110
113
111
114
112
def match_targets (
115
- name : str , module : torch .nn .Module , targets : Iterable [str ]
113
+ name : str , module : torch .nn .Module , targets : Iterable [str ] | None = None
116
114
) -> List [str ]:
117
115
"""
118
116
Returns the targets that match the given name and module.
117
+
118
+ :param name: the name of the module
119
+ :param module: the module to match
120
+ :param targets: the target strings, potentially containing "re:" prefixes
121
+ :return: the targets that match the given name and module
122
+
119
123
Outputs are ordered by type: exact name match, regex name match, class name match
120
124
"""
125
+ targets = targets or []
126
+
121
127
if isinstance (module , InternalModule ):
122
128
return []
123
129
130
+ # The order of the output `matches` list matters, the are arranged from most
131
+ # specific to least specific, and this order will be used when merging configs.
132
+ # The entries are sorted in the following order:
133
+ # 1. matches on exact strings
134
+ # 2. matches on regex patterns
135
+ # 3. matches on module names
136
+
124
137
targets = sorted (targets , key = lambda x : ("re:" in x , x ))
125
138
matched_targets = []
126
139
for target in targets :
@@ -136,8 +149,8 @@ def match_targets(
136
149
137
150
def match_modules_set (
138
151
model : torch .nn .Module ,
139
- targets : Iterable [str ],
140
- ignore : Iterable [str ] = tuple () ,
152
+ targets : Iterable [str ] | None = None ,
153
+ ignore : Iterable [str ] | None = None ,
141
154
) -> Generator [Iterable [torch .nn .Module ]]:
142
155
"""
143
156
Yields modules grouped with the same order and size as `targets`.
@@ -175,6 +188,9 @@ def match_modules_set(
175
188
:param targets: target strings, potentially containing "re:" prefixes
176
189
:param ignore: targets to ignore, potentially containing "re:" prefixes
177
190
"""
191
+ targets = targets or []
192
+ ignore = ignore or []
193
+
178
194
matches = dict .fromkeys (targets , None )
179
195
for name , module in model .named_modules ():
180
196
# match until we get a full set
0 commit comments