15
15
import logging
16
16
import re
17
17
from collections .abc import Generator
18
- from typing import Iterable , Tuple
18
+ from typing import Iterable , Mapping , Optional , Tuple
19
19
20
20
import torch
21
21
from compressed_tensors .utils .internal import InternalModule
32
32
]
33
33
34
34
35
# Maps the name suffix of a fused module to the name suffixes of its shards,
# e.g. {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
FusedMapping = Mapping[str, Iterable[str]]
# Misspelled name kept as an alias because existing signatures reference it
FusedMappping = FusedMapping
36
+
37
+
35
38
def match_named_modules (
36
39
model : torch .nn .Module ,
37
40
targets : Iterable [str ],
38
41
ignore : Iterable [str ] = tuple (),
42
+ fused : Optional [FusedMappping ] = None ,
39
43
warn_on_fail : bool = False ,
40
44
) -> Generator [Tuple [str , torch .nn .Module ]]:
41
45
"""
@@ -45,16 +49,18 @@ def match_named_modules(
45
49
:param model: model containing submodules to match against
46
50
:param targets: target strings, potentially containing "re:" prefixes
47
51
:param ignore: targets to ignore, potentially containing "re:" prefixes
52
+ :param fused: optional mapping from suffixes of fused modules to the suffixes of their
53
+ corresponding shards. See `compressed_tensors.utils.match.is_match`
48
54
:param warn_on_fail: if True, warns if any targets do not match any modules in model
49
55
:return: generator of module names and modules
50
56
"""
51
57
unmatched_targets = set (targets )
52
58
for name , module in model .named_modules ():
53
59
for target in targets :
54
- if is_match (name , module , target ):
60
+ if is_match (name , module , target , fused ):
55
61
unmatched_targets -= {target }
56
62
57
- if not any (is_match (name , module , ign ) for ign in ignore ):
63
+ if not any (is_match (name , module , ign , fused ) for ign in ignore ):
58
64
yield name , module
59
65
60
66
if warn_on_fail :
@@ -68,6 +74,7 @@ def match_named_parameters(
68
74
model : torch .nn .Module ,
69
75
targets : Iterable [str ],
70
76
ignore : Iterable [str ] = tuple (),
77
+ fused : Optional [FusedMappping ] = None ,
71
78
warn_on_fail : bool = False ,
72
79
) -> Generator [Tuple [str , torch .nn .Module , torch .nn .Parameter ]]:
73
80
"""
@@ -77,6 +84,8 @@ def match_named_parameters(
77
84
:param model: model containing params to match against
78
85
:param targets: target strings, potentially containing "re:" prefixes
79
86
:param ignore: targets to ignore, potentially containing "re:" prefixes
87
+ :param fused: optional mapping from suffixes of fused modules to the suffixes of their
88
+ corresponding shards. See `compressed_tensors.utils.match.is_match`
80
89
:param warn_on_fail: if True, warns if any targets do not match any params in model
81
90
:return: generator of fully-qualified param names, parent modules, and params
82
91
"""
@@ -88,10 +97,10 @@ def match_named_parameters(
88
97
for param_name , param in module .named_parameters (recurse = False ):
89
98
param_fqn = f"{ module_name } .{ param_name } "
90
99
for target in targets :
91
- if _match_name (param_fqn , target ):
100
+ if _match_name (param_fqn , target , fused ):
92
101
unmatched_targets -= {target }
93
102
94
- if not any (_match_name (param_fqn , ign ) for ign in ignore ):
103
+ if not any (_match_name (param_fqn , ign , fused ) for ign in ignore ):
95
104
yield param_fqn , module , param
96
105
97
106
if warn_on_fail :
@@ -164,21 +173,56 @@ def match_modules_set(
164
173
raise ValueError (f"Unable to match targets into set: { unmatched_keys } " )
165
174
166
175
167
def is_match(
    name: str,
    module: torch.nn.Module,
    target: str,
    fused: Optional[FusedMappping] = None,
) -> bool:
    """
    Decide whether a named module is selected by a target.

    A module is selected when it is not an internal module and the target matches
    either its name or one of its parent classes. The name and module may refer to
    a fused module defined by vLLM, in which case a `fused` mapping must be provided
    so that the fused name can be matched through its shard names.

    For example, in `vllm/model_executor/models/llama.py`:
    ```python
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"]
    }
    ```

    :param name: name of module
    :param module: module to match
    :param target: target which matches name or module, potentially contains regex
    :param fused: optional mapping from suffixes of fused modules to the suffixes of
        their corresponding shards
    :return: True if the module is not internal and the target matches its name or
        class, False otherwise
    """
    # internal modules are never matched, regardless of their name or class
    if isinstance(module, InternalModule):
        return False

    return _match_name(name, target, fused) or _match_class(module, target)
175
204
176
205
177
- def _match_name (name : str , target : str ) -> bool :
206
+ def _match_name (name : str , target : str , fused : Optional [ FusedMappping ] = None ) -> bool :
178
207
"""
179
- Returns true if target string begins with "re:" and
180
- regex matches or if target string exactly matches name
208
+ Returns true if target string begins with "re:" and regex matches or if target
209
+ string exactly matches name. If the name refers to a fused module defined by vLLM,
210
+ a `fused` mapping must be provided.
211
+
212
+ :param name: name of module
213
+ :param target: target name, potentially contains regex
214
+ :param fused: optional mapping from suffixes of fused modules to the suffixes of their
215
+ corresponding shards
181
216
"""
217
+ if fused is not None :
218
+ for fused_suffix in fused :
219
+ if name .endswith (fused_suffix ):
220
+ name_stripped = name .removesuffix (fused_suffix )
221
+ return any (
222
+ _match_name (name_stripped + shard_suffix , target )
223
+ for shard_suffix in fused [fused_suffix ]
224
+ )
225
+
182
226
if target .startswith ("re:" ):
183
227
return re .match (target .removeprefix ("re:" ), name ) is not None
184
228
else :
@@ -187,10 +231,20 @@ def _match_name(name: str, target: str) -> bool:
187
231
188
232
def _match_class (module : torch .nn .Module , target : str ) -> bool :
189
233
"""
190
- Returns true if any torch parent class names match the target string exactly
234
+ Returns true if any torch parent class names match the target string exactly.
235
+ A special exception is made for vllm's `LinearBase` class which matches `Linear`
236
+
237
+ :param module: module to match
238
+ :param target: target which matches name or module
191
239
"""
192
240
# will never match against a regex pattern since `:` is not allowed in class names
193
241
return any (
194
- issubclass (cls , torch .nn .Module ) and cls .__name__ == target
242
+ (
243
+ issubclass (cls , torch .nn .Module )
244
+ and (
245
+ cls .__name__ == target
246
+ or (cls .__name__ == "LinearBase" and target == "Linear" )
247
+ )
248
+ )
195
249
for cls in module .__class__ .__mro__
196
250
)
0 commit comments