14
14
15
15
import logging
16
16
import re
17
+ from collections import OrderedDict
17
18
from collections .abc import Generator
18
19
from typing import Iterable , Tuple
19
20
23
24
_LOGGER : logging .Logger = logging .getLogger (__name__ )
24
25
25
26
26
- __all__ = ["match_named_modules" , "is_match" ]
27
+ __all__ = [
28
+ "match_named_modules" ,
29
+ "match_named_parameters" ,
30
+ "match_modules_set" ,
31
+ "is_match" ,
32
+ "match_name" ,
33
+ "match_class" ,
34
+ ]
27
35
28
36
29
37
def match_named_modules (
30
38
model : torch .nn .Module ,
31
39
targets : Iterable [str ] = tuple (),
32
40
ignore : Iterable [str ] = tuple (),
33
- warn_on_fail : bool = True ,
34
- ) -> Generator [Tuple [str , torch .nn .Module ], None , None ]:
41
+ warn_on_fail : bool = False ,
42
+ ) -> Generator [Tuple [str , torch .nn .Module ]]:
43
+ """
44
+ Yields names and modules which match `targets` but do not match `ignore`.
45
+ Values are returned in order of `model.named_modules()`
46
+
47
+ :param model: model containing submodules to match against
48
+ :param targets: target strings, potentially containing "re:" prefixes
49
+ :param ignore: targets to ignore, potentially containing "re:" prefixes
50
+ :param warn_on_fail: if True, warns if any targets do not match any modules in model
51
+ :return: generator of module names and modules
52
+ """
35
53
unmatched_targets = set (targets )
36
54
for name , module in model .named_modules ():
37
55
for target in targets :
38
56
if is_match (name , module , target ):
39
- unmatched_targets . remove ( target )
57
+ unmatched_targets -= { target }
40
58
41
59
if not any (is_match (name , module , ign ) for ign in ignore ):
42
60
yield name , module
@@ -48,22 +66,126 @@ def match_named_modules(
48
66
)
49
67
50
68
69
+ def match_named_parameters (
70
+ model : torch .nn .Module ,
71
+ targets : Iterable [str ],
72
+ ignore : Iterable [str ],
73
+ warn_on_fail : bool = False ,
74
+ ) -> Generator [Tuple [str , torch .nn .Module , torch .nn .Parameter ]]:
75
+ """
76
+ Yields parameters which match `targets` but do not match `ignore`.
77
+ Values are returned in order of `model.named_modules()`
78
+
79
+ :param model: model containing params to match against
80
+ :param targets: target strings, potentially containing "re:" prefixes
81
+ :param ignore: targets to ignore, potentially containing "re:" prefixes
82
+ :param warn_on_fail: if True, warns if any targets do not match any params in model
83
+ :return: generator of fully-qualified param names, parent modules, and params
84
+ """
85
+ unmatched_targets = set (targets )
86
+ for module_name , module in model .named_modules ():
87
+ for param_name , param in module .named_parameters (recurse = False ):
88
+ param_fqn = f"{ module_name } .{ param_name } "
89
+ for target in targets :
90
+ if match_name (param_fqn , target ):
91
+ unmatched_targets -= {target }
92
+
93
+ if not any (match_name (param_fqn , ign ) for ign in ignore ):
94
+ yield param_fqn , module , param
95
+
96
+ if warn_on_fail :
97
+ for target in unmatched_targets :
98
+ _LOGGER .warning (
99
+ f"Could not match `{ target } ` in instance of { model .__class__ .__name__ } "
100
+ )
101
+
102
+
103
+ def match_modules_set (
104
+ model : torch .nn .Module ,
105
+ targets : Iterable [str ],
106
+ ignore : Iterable [str ],
107
+ ) -> Generator [Iterable [torch .nn .Module ]]:
108
+ """
109
+ Yields modules grouped with the same order and size as `targets`.
110
+ Values are returned in order of `model.named_modules()`
111
+
112
+ For example, the following targets would yield module belonging to the following layers:
113
+ ```python3
114
+ match_modules_set(model, ["q_proj", "k_proj", "v_proj"]) == (
115
+ (
116
+ `model.layers.0.self_attn.q_proj`,
117
+ `model.layers.0.self_attn.k_proj`,
118
+ `model.layers.0.self_attn.v_proj`,
119
+ ),
120
+ (
121
+ `model.layers.1.self_attn.q_proj`,
122
+ `model.layers.1.self_attn.k_proj`,
123
+ `model.layers.1.self_attn.v_proj`,
124
+ ),
125
+ ...
126
+ (
127
+ `model.layers.32.self_attn.q_proj`,
128
+ `model.layers.32.self_attn.k_proj`,
129
+ `model.layers.32.self_attn.v_proj`,
130
+ ),
131
+ )
132
+ ```
133
+
134
+ This can be used to match layers to their corresponding downstream counterparts.
135
+ For example, matching layer norms to their subsequent linear layers
136
+ ```python3
137
+ for norm, q, k, v in match_modules_set(model, (norm_tgt, q_tgt, k_tgt, v_tgt)):
138
+ fuse_norm_linears(norm, [q, k, v])
139
+
140
+ :param model: model containing modules to match against
141
+ :param targets: target strings, potentially containing "re:" prefixes
142
+ :param ignore: targets to ignore, potentially containing "re:" prefixes
143
+ """
144
+ matches = dict .fromkeys (targets , None )
145
+ for name , module in model .named_modules ():
146
+ # match until we get a full set
147
+ for target in targets :
148
+ if is_match (name , module , target ) and not any (
149
+ is_match (name , module , ign ) for ign in ignore
150
+ ):
151
+ if matches [target ] is not None :
152
+ raise ValueError (f"Matched a { target } twice before completing set" )
153
+ matches [target ] = module
154
+
155
+ # once we have a full set, yield and reset
156
+ if all (matches [target ] is not None for target in targets ):
157
+ yield [matches [target ] for target in targets ] # ensure correct ordering
158
+ matches = dict .fromkeys (targets , None )
159
+
160
+ # check that none are left over
161
+ unmatched_keys = [match for match , value in matches .items () if value is None ]
162
+ if len (unmatched_keys ):
163
+ raise ValueError (f"Unable to match targets into set: { unmatched_keys } " )
164
+
165
+
51
166
def is_match (name : str , module : torch .nn .Module , target : str ) -> bool :
52
- return _match_name (name , target ) or _match_class (module , target )
167
+ """
168
+ Returns true if either module name or module parent classes match against target
169
+ """
170
+ return match_name (name , target ) or match_class (module , target )
53
171
54
172
55
- def _match_name (name : str , target : str ) -> bool :
173
+ def match_name (name : str , target : str ) -> bool :
174
+ """
175
+ Returns true if target string begins with "re:" and
176
+ regex matches or if target string exactly matches name
177
+ """
56
178
if target .startswith ("re:" ):
57
179
return re .match (target .removeprefix ("re:" ), name )
58
180
else :
59
181
return target == name
60
182
61
183
62
- def _match_class (module : torch .nn .Module , target : str ) -> bool :
184
+ def match_class (module : torch .nn .Module , target : str ) -> bool :
63
185
"""
64
- Will never match against a regex pattern since `:` is not allowed in class names
65
-
186
+ Returns true if any torch parent class names match the target string exactly
66
187
"""
188
+ # will never match against a regex pattern since `:` is not allowed in class names
67
189
return any (
68
190
issubclass (cls , torch .nn .Module ) and cls .__name__ == target
69
191
for cls in module .__class__ .__mro__
0 commit comments