Commit c69c8bb

Kyle Sayers authored
[Utils] Support matching vLLM modules (#413)
* support matching vllm modules
* propagate argument

Signed-off-by: Kyle Sayers <[email protected]>
1 parent 46d84d8 commit c69c8bb
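To sketch the intended usage, here is a minimal, self-contained example of the new `fused` argument; the `TinyAttn` module and its names are hypothetical stand-ins for a vLLM layer, while `match_named_modules` and `fused` come from this commit:

```python
import torch.nn as nn

from compressed_tensors.utils import match_named_modules


class TinyAttn(nn.Module):
    """Toy stand-in for a vLLM attention block whose q/k/v are fused."""

    def __init__(self):
        super().__init__()
        self.qkv_proj = nn.Linear(8, 24)  # replaces separate q/k/v projections


# maps each fused module suffix to the shard suffixes it packs
fused = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

# A target written against the unfused name ("q_proj") now matches the
# fused module ("qkv_proj") when the mapping is supplied:
matches = list(match_named_modules(TinyAttn(), targets=["re:.*q_proj"], fused=fused))
assert [name for name, _ in matches] == ["qkv_proj"]
```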

File tree

2 files changed: +98 -14 lines changed

- src/compressed_tensors/utils/match.py
- tests/test_utils/test_match.py

src/compressed_tensors/utils/match.py

Lines changed: 67 additions & 13 deletions
@@ -15,7 +15,7 @@
 import logging
 import re
 from collections.abc import Generator
-from typing import Iterable, Tuple
+from typing import Iterable, Mapping, Optional, Tuple
 
 import torch
 from compressed_tensors.utils.internal import InternalModule
@@ -32,10 +32,14 @@
 ]
 
 
+FusedMappping = Mapping[str, Iterable[str]]
+
+
 def match_named_modules(
     model: torch.nn.Module,
     targets: Iterable[str],
     ignore: Iterable[str] = tuple(),
+    fused: Optional[FusedMappping] = None,
     warn_on_fail: bool = False,
 ) -> Generator[Tuple[str, torch.nn.Module]]:
     """
@@ -45,16 +49,18 @@ def match_named_modules(
     :param model: model containing submodules to match against
     :param targets: target strings, potentially containing "re:" prefixes
     :param ignore: targets to ignore, potentially containing "re:" prefixes
+    :fused: optional mapping from suffixes of fused modules to the suffixes of their
+        corresponding shards. See `compressed_tensors.utils.match.is_match`
     :param warn_on_fail: if True, warns if any targets do not match any modules in model
     :return: generator of module names and modules
     """
     unmatched_targets = set(targets)
     for name, module in model.named_modules():
         for target in targets:
-            if is_match(name, module, target):
+            if is_match(name, module, target, fused):
                 unmatched_targets -= {target}
 
-                if not any(is_match(name, module, ign) for ign in ignore):
+                if not any(is_match(name, module, ign, fused) for ign in ignore):
                     yield name, module
 
     if warn_on_fail:
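For reference, a value conforming to the new `FusedMappping` alias looks like the following sketch; the concrete entries mirror the vLLM example quoted in the `is_match` docstring further down this diff:

```python
from typing import Iterable, Mapping

FusedMappping = Mapping[str, Iterable[str]]

# keys are suffixes of fused module names; values are the suffixes
# of the original (unfused) shard modules that each one packs
fused: FusedMappping = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}
```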
@@ -68,6 +74,7 @@ def match_named_parameters(
     model: torch.nn.Module,
     targets: Iterable[str],
     ignore: Iterable[str] = tuple(),
+    fused: Optional[FusedMappping] = None,
     warn_on_fail: bool = False,
 ) -> Generator[Tuple[str, torch.nn.Module, torch.nn.Parameter]]:
     """
@@ -77,6 +84,8 @@ def match_named_parameters(
     :param model: model containing params to match against
     :param targets: target strings, potentially containing "re:" prefixes
     :param ignore: targets to ignore, potentially containing "re:" prefixes
+    :fused: optional mapping from suffixes of fused modules to the suffixes of their
+        corresponding shards. See `compressed_tensors.utils.match.is_match`
     :param warn_on_fail: if True, warns if any targets do not match any params in model
     :return: generator of fully-qualified param names, parent modules, and params
     """
@@ -88,10 +97,10 @@
         for param_name, param in module.named_parameters(recurse=False):
             param_fqn = f"{module_name}.{param_name}"
             for target in targets:
-                if _match_name(param_fqn, target):
+                if _match_name(param_fqn, target, fused):
                     unmatched_targets -= {target}
 
-                    if not any(_match_name(param_fqn, ign) for ign in ignore):
+                    if not any(_match_name(param_fqn, ign, fused) for ign in ignore):
                         yield param_fqn, module, param
 
     if warn_on_fail:
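For completeness, a small self-contained sketch of `match_named_parameters` on a plain (unfused) module; the toy model below is hypothetical, and the import path assumes the helper is exported from `compressed_tensors.utils` alongside the others used in this diff:

```python
import torch.nn as nn

from compressed_tensors.utils import match_named_parameters

model = nn.Sequential(nn.Linear(4, 4))

# targets are checked against fully-qualified parameter names such as "0.weight"
fqns = [fqn for fqn, _, _ in match_named_parameters(model, targets=["re:.*weight"])]
assert fqns == ["0.weight"]
```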
@@ -164,21 +173,56 @@ def match_modules_set(
         raise ValueError(f"Unable to match targets into set: {unmatched_keys}")
 
 
-def is_match(name: str, module: torch.nn.Module, target: str) -> bool:
+def is_match(
+    name: str,
+    module: torch.nn.Module,
+    target: str,
+    fused: Optional[FusedMappping] = None,
+) -> bool:
     """
     Returns true if either module name or module parent classes match against target
-    and the module is not an internal module
+    and the module is not an internal module. The name and module may refer to a fused
+    module defined by vLLM. In these cases, a `fused` mapping must be provided.
+
+    For example, in `vllm/model_executor/models/llama.py`:
+    ```python
+    packed_modules_mapping = {
+        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+        "gate_up_proj": ["gate_proj", "up_proj"]
+    }
+    ```
+
+    :param name: name of module
+    :param module: module to match
+    :param target: target which matches name or module, potentially contains regex
+    :fused: optional mapping from suffixes of fused modules to the suffixes of their
+        corresponding shards
     """
     return not isinstance(module, InternalModule) and (
-        _match_name(name, target) or _match_class(module, target)
+        _match_name(name, target, fused) or _match_class(module, target)
     )
 
 
-def _match_name(name: str, target: str) -> bool:
+def _match_name(name: str, target: str, fused: Optional[FusedMappping] = None) -> bool:
     """
-    Returns true if target string begins with "re:" and
-    regex matches or if target string exactly matches name
+    Returns true if target string begins with "re:" and regex matches or if target
+    string exactly matches name. If the name refers to a fused module defined by vLLM,
+    a `fused` mapping must be provided.
+
+    :param name: name of module
+    :param target: target name, potentially contains regex
+    :fused: optional mapping from suffixes of fused modules to the suffixes of their
+        corresponding shards
     """
+    if fused is not None:
+        for fused_suffix in fused:
+            if name.endswith(fused_suffix):
+                name_stripped = name.removesuffix(fused_suffix)
+                return any(
+                    _match_name(name_stripped + shard_suffix, target)
+                    for shard_suffix in fused[fused_suffix]
+                )
+
     if target.startswith("re:"):
         return re.match(target.removeprefix("re:"), name) is not None
     else:
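To make the fused-name resolution concrete, a short worked example of `_match_name` with and without a mapping; the module names are hypothetical, and the private helper is imported directly from its module:

```python
from compressed_tensors.utils.match import _match_name

fused = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

# "model.layers.0.self_attn.qkv_proj" ends with the fused suffix "qkv_proj",
# so the match is retried against each reconstructed shard name:
#   "model.layers.0.self_attn.q_proj"
#   "model.layers.0.self_attn.k_proj"
#   "model.layers.0.self_attn.v_proj"
assert _match_name("model.layers.0.self_attn.qkv_proj", "re:.*q_proj", fused)

# without the mapping, the fused name itself does not satisfy the target
assert not _match_name("model.layers.0.self_attn.qkv_proj", "re:.*q_proj")
```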
@@ -187,10 +231,20 @@ def _match_name(name: str, target: str) -> bool:
 
 def _match_class(module: torch.nn.Module, target: str) -> bool:
     """
-    Returns true if any torch parent class names match the target string exactly
+    Returns true if any torch parent class names match the target string exactly.
+    A special exception is made for vllm's `LinearBase` class which matches `Linear`
+
+    :param module: module to match
+    :param target: target which matches name or module
     """
     # will never match against a regex pattern since `:` is not allowed in class names
     return any(
-        issubclass(cls, torch.nn.Module) and cls.__name__ == target
+        (
+            issubclass(cls, torch.nn.Module)
+            and (
+                cls.__name__ == target
+                or (cls.__name__ == "LinearBase" and target == "Linear")
+            )
+        )
         for cls in module.__class__.__mro__
     )
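The `LinearBase` carve-out matters because vLLM's parallel linear layers subclass its `LinearBase` rather than `torch.nn.Linear`, so `"Linear"` targets would otherwise stop matching them. A minimal sketch of the behavior, using local stand-in classes rather than the real vLLM imports:

```python
import torch.nn as nn

from compressed_tensors.utils.match import _match_class


class LinearBase(nn.Module):  # stand-in for vllm's LinearBase
    pass


class ColumnParallelLinear(LinearBase):  # hypothetical vLLM-style subclass
    pass


# any module whose MRO contains a class literally named "LinearBase"
# now satisfies the "Linear" target
assert _match_class(ColumnParallelLinear(), "Linear")
```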

tests/test_utils/test_match.py

Lines changed: 31 additions & 1 deletion
@@ -16,7 +16,6 @@
 
 import pytest
 import torch.nn as nn
-from accelerate import init_empty_weights
 
 # Assuming the module is named "module_matching" - adjust import as needed
 from compressed_tensors.utils import (
@@ -33,6 +32,11 @@ class DummyModel(nn.Module):
     """Test model for unit tests. Weights are initialized on meta device"""
 
     def __init__(self):
+        try:
+            from accelerate import init_empty_weights
+        except ImportError:
+            pytest.skip("Skipping weight init requires accelerate")
+
         super().__init__()
         with init_empty_weights():
             self.layer1 = nn.Linear(10, 20)
@@ -142,6 +146,15 @@ def test_custom_module(self):
         assert _match_class(model, "DummyModel") == True
         assert _match_class(model, "Module") == True
 
+    def test_linear_base(self):
+        """Test matching against vllm's LinearBase class"""
+
+        class LinearBase(nn.Module):
+            pass
+
+        linear = LinearBase()
+        assert _match_class(linear, "Linear") == True
+
 
 class TestIsMatch:
     """Test cases for is_match function"""
@@ -180,6 +193,23 @@ class InternalLinear(InternalModule, nn.Linear):
         linear = InternalLinear(10, 20)
         assert is_match("layer1", linear, "re:layer.*") == False
 
+    def test_fused_mapping(self):
+        """Test that fused module names match against their unfused shard targets"""
+        linear = nn.Linear(10, 20)
+        mapping = {
+            "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+            "gate_up_proj": ["gate_proj", "up_proj"],
+        }
+
+        assert is_match("dummy.qkv_proj", linear, "re:.*q_proj", mapping) == True
+        assert is_match("dummy.qkv_proj", linear, "re:.*k_proj", mapping) == True
+        assert is_match("dummy.qkv_proj", linear, "re:.*v_proj", mapping) == True
+        assert is_match("dummy.qkv_proj", linear, "Linear", mapping) == True
+
+        assert is_match("dummy.gate_up_proj", linear, "re:.*gate_proj", mapping) == True
+        assert is_match("dummy.gate_up_proj", linear, "re:.*up_proj", mapping) == True
+        assert is_match("dummy.gate_up_proj", linear, "Linear", mapping) == True
+
 
 class TestMatchNamedModules:
     """Test cases for match_named_modules function"""
