Commit 684db8b

Merge branch 'kylesayrs/transform-precision' into kylesayrs/transform-merge

2 parents: dcefc0b + 5db0e13

5 files changed: +116 −28 lines

src/compressed_tensors/transform/factory/hadamard.py
Lines changed: 7 additions & 7 deletions

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import math
 from typing import Optional
 
 import torch
@@ -54,15 +53,14 @@ def create_transform(self, module: Module, args: TransformArgs):
         """
         assert hasattr(module, "weight")
         size = get_transform_size(module, args.location, self.scheme.head_dim)
-        dtype = module.weight.dtype
+        dtype = self.scheme.precision
         device = get_offloaded_device(module)
         exec_device = get_execution_device(module)
-        precision = self.scheme.precision
 
         factory_kwargs = {"construct_device": exec_device}
         weight = self.weights.get(size, dtype, device, factory_kwargs=factory_kwargs)
         perm = self.perms[weight] if self.scheme.randomize else None
-        return HadamardTransform(weight, perm, args, precision, type(module))
+        return HadamardTransform(weight, perm, self.scheme, args, type(module))
 
     def _create_weight(
         self,
@@ -86,17 +84,19 @@ def __init__(
         self,
         weight: Parameter,
         perm: Optional[Parameter],
+        scheme: TransformScheme,
         args: TransformArgs,
         precision: torch.dtype,
         module_type: type[torch.nn.Module],
     ):
         super().__init__()
         self.weight = weight
         self.perm = perm
+        self.scheme = scheme
         self.args = args
         self.precision = precision
         self.module_type = module_type
-        self._scale = torch.tensor(weight.size(0), dtype=self.precision).sqrt()
+        self._scale = torch.tensor(weight.size(0), dtype=self.scheme.precision).sqrt()
 
     def forward(self, value: Tensor) -> Tensor:
         weight = self.weight
@@ -109,8 +109,8 @@ def forward(self, value: Tensor) -> Tensor:
 
         return (
             apply_transform_weight(
-                weight.to(self.precision),
-                value.to(self.precision),
+                weight.to(self.scheme.precision),
+                value.to(self.scheme.precision),
                 self.args.location,
                 self.module_type,
             )
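The change above routes the application dtype through `scheme.precision`: both the Hadamard weight and the incoming value are cast to that precision before multiplying, and the result is cast back at the call site. A minimal standalone sketch of the pattern, independent of the library's APIs (the helper name `apply_in_precision` and the shapes are illustrative, not part of this codebase):

```python
import torch

def apply_in_precision(
    weight: torch.Tensor, value: torch.Tensor, precision: torch.dtype = torch.float32
) -> torch.Tensor:
    """Apply an orthonormal transform in `precision`, then restore the input dtype."""
    out = value.to(precision) @ weight.to(precision)
    return out.to(value.dtype)

# 2x2 Hadamard matrix, scaled by 1/sqrt(size) so the transform is orthonormal
# (the diff's `_scale` term plays the same normalization role)
H = torch.tensor([[1.0, 1.0], [1.0, -1.0]]) / torch.tensor(2.0).sqrt()

x = torch.randn(4, 2, dtype=torch.bfloat16)
y = apply_in_precision(H, x)         # rotate in float32
x_back = apply_in_precision(H.T, y)  # orthonormal, so the transpose inverts it
assert torch.allclose(x.float(), x_back.float(), atol=1e-2)
```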

src/compressed_tensors/transform/factory/matrix_multiply.py
Lines changed: 8 additions & 6 deletions

@@ -24,7 +24,7 @@
 from compressed_tensors.utils import get_offloaded_device
 from compressed_tensors.utils.helpers import ParameterizedDefaultDict
 from torch import Tensor, device, dtype
-from torch.nn import Linear, Module, Parameter
+from torch.nn import Module, Parameter
 
 
 @TransformFactory.register("random-matrix")
@@ -52,7 +52,7 @@ def create_transform(self, module: Module, args: TransformArgs):
         """
         assert hasattr(module, "weight")
         size = get_transform_size(module, args.location, self.scheme.head_dim)
-        dtype = module.weight.dtype
+        dtype = self.scheme.precision
         device = get_offloaded_device(module)
         precision = self.scheme.precision
 
@@ -79,29 +79,31 @@ class RandomMatrixTransform(TransformBase):
     def __init__(
         self,
         weight: Tensor,
+        scheme: TransformScheme,
         args: TransformArgs,
         precision: torch.dtype,
         module_type: type[torch.nn.Module],
     ):
         super().__init__()
         self.weight = weight  # is an inverse if args.inverse
+        self.scheme = scheme
         self.args = args
         self.precision = precision
         self.module_type = module_type
 
     def forward(self, value: Tensor) -> Parameter:
         return apply_transform_weight(
-            self.weight.to(self.precision),
-            value.to(self.precision),
+            self.weight.to(self.scheme.precision),
+            value.to(self.scheme.precision),
             self.args.location,
             self.module_type,
         ).to(value.dtype)
 
     def right_inverse(self, value: Tensor) -> Tensor:
         inverse = high_precision_invert(self.weight)
         return apply_transform_weight(
-            inverse.to(self.precision),
-            value.to(self.precision),
+            inverse.to(self.scheme.precision),
+            value.to(self.scheme.precision),
             self.args.location,
             self.module_type,
        ).to(value.dtype)
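As in the Hadamard factory, the random-matrix transform now multiplies in `scheme.precision` and casts the result back to the caller's dtype; `right_inverse` additionally inverts the weight via `high_precision_invert`, whose body is not part of this diff. A hedged sketch of the forward/inverse round trip, under the assumption that the inversion happens in float64 (all names below are stand-ins for illustration, not the library's API):

```python
import torch

def high_precision_invert(weight: torch.Tensor) -> torch.Tensor:
    # assumption: invert in float64 for stability, then restore the original dtype
    return torch.linalg.inv(weight.to(torch.float64)).to(weight.dtype)

precision = torch.float32
W = torch.randn(8, 8)  # a random square matrix is invertible with high probability
x = torch.randn(3, 8, dtype=torch.bfloat16)

y = (x.to(precision) @ W.to(precision)).to(x.dtype)  # forward
x_back = (y.to(precision) @ high_precision_invert(W).to(precision)).to(y.dtype)

# the round trip is only approximate: bfloat16 storage limits the precision
print((x.float() - x_back.float()).abs().max())
```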

src/compressed_tensors/transform/transform_scheme.py
Lines changed: 3 additions & 1 deletion

@@ -36,11 +36,13 @@ class TransformScheme(BaseModel):
     :param randomize: True if uniquely randomized transform weights should be used,
         otherwise use identical transform weights where applicable
     :param requires_grad: True if weights include gradients for training
+    :param precision: Precision at which this transform should be applied. This applies
+        to both weight fusing and online rotations
     """
 
     type: str
     apply: List[TransformArgs] = Field(default_factory=list)
     randomize: bool = Field(default=False)
     requires_grad: bool = Field(default=False)
     head_dim: Optional[int] = Field(default=None)
-    precision: TorchDtype = Field(default=torch.bfloat16)
+    precision: TorchDtype = Field(default=torch.float32)
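The default application precision therefore moves from `torch.bfloat16` to `torch.float32`. A brief usage sketch, assuming `TransformScheme` is exported from `compressed_tensors.transform` as it is elsewhere in this repo:

```python
import torch
from compressed_tensors.transform import TransformScheme

scheme = TransformScheme(type="hadamard")
assert scheme.precision == torch.float32  # new default after this change

# callers that want the old behavior can opt back in explicitly
fast_scheme = TransformScheme(type="hadamard", precision=torch.bfloat16)
```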

src/compressed_tensors/utils/match.py
Lines changed: 67 additions & 13 deletions

@@ -15,7 +15,7 @@
 import logging
 import re
 from collections.abc import Generator
-from typing import Iterable, Tuple
+from typing import Iterable, Mapping, Optional, Tuple
 
 import torch
 from compressed_tensors.utils.internal import InternalModule
@@ -32,10 +32,14 @@
 ]
 
 
+FusedMappping = Mapping[str, Iterable[str]]
+
+
 def match_named_modules(
     model: torch.nn.Module,
     targets: Iterable[str],
     ignore: Iterable[str] = tuple(),
+    fused: Optional[FusedMappping] = None,
     warn_on_fail: bool = False,
 ) -> Generator[Tuple[str, torch.nn.Module]]:
     """
@@ -45,16 +49,18 @@ def match_named_modules(
     :param model: model containing submodules to match against
     :param targets: target strings, potentially containing "re:" prefixes
     :param ignore: targets to ignore, potentially containing "re:" prefixes
+    :param fused: optional mapping from suffixes of fused modules to the suffixes
+        of their corresponding shards. See `compressed_tensors.utils.match.is_match`
     :param warn_on_fail: if True, warns if any targets do not match any modules in model
     :return: generator of module names and modules
     """
     unmatched_targets = set(targets)
     for name, module in model.named_modules():
         for target in targets:
-            if is_match(name, module, target):
+            if is_match(name, module, target, fused):
                 unmatched_targets -= {target}
 
-                if not any(is_match(name, module, ign) for ign in ignore):
+                if not any(is_match(name, module, ign, fused) for ign in ignore):
                     yield name, module
 
     if warn_on_fail:
@@ -68,6 +74,7 @@ def match_named_parameters(
     model: torch.nn.Module,
     targets: Iterable[str],
     ignore: Iterable[str] = tuple(),
+    fused: Optional[FusedMappping] = None,
     warn_on_fail: bool = False,
 ) -> Generator[Tuple[str, torch.nn.Module, torch.nn.Parameter]]:
     """
@@ -77,6 +84,8 @@ def match_named_parameters(
     :param model: model containing params to match against
     :param targets: target strings, potentially containing "re:" prefixes
     :param ignore: targets to ignore, potentially containing "re:" prefixes
+    :param fused: optional mapping from suffixes of fused modules to the suffixes
+        of their corresponding shards. See `compressed_tensors.utils.match.is_match`
     :param warn_on_fail: if True, warns if any targets do not match any params in model
     :return: generator of fully-qualified param names, parent modules, and params
     """
@@ -88,10 +97,10 @@ def match_named_parameters(
         for param_name, param in module.named_parameters(recurse=False):
             param_fqn = f"{module_name}.{param_name}"
             for target in targets:
-                if _match_name(param_fqn, target):
+                if _match_name(param_fqn, target, fused):
                     unmatched_targets -= {target}
 
-                    if not any(_match_name(param_fqn, ign) for ign in ignore):
+                    if not any(_match_name(param_fqn, ign, fused) for ign in ignore):
                         yield param_fqn, module, param
 
     if warn_on_fail:
@@ -164,21 +173,56 @@ def match_modules_set(
         raise ValueError(f"Unable to match targets into set: {unmatched_keys}")
 
 
-def is_match(name: str, module: torch.nn.Module, target: str) -> bool:
+def is_match(
+    name: str,
+    module: torch.nn.Module,
+    target: str,
+    fused: Optional[FusedMappping] = None,
+) -> bool:
     """
     Returns true if either module name or module parent classes match against target
-    and the module is not an internal module
+    and the module is not an internal module. The name and module may refer to a fused
+    module defined by vLLM. In these cases, a `fused` mapping must be provided.
+
+    For example, in `vllm/model_executor/models/llama.py`:
+    ```python
+    packed_modules_mapping = {
+        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+        "gate_up_proj": ["gate_proj", "up_proj"]
+    }
+    ```
+
+    :param name: name of module
+    :param module: module to match
+    :param target: target which matches name or module, potentially contains regex
+    :param fused: optional mapping from suffixes of fused modules to the suffixes
+        of their corresponding shards
     """
     return not isinstance(module, InternalModule) and (
-        _match_name(name, target) or _match_class(module, target)
+        _match_name(name, target, fused) or _match_class(module, target)
     )
 
 
-def _match_name(name: str, target: str) -> bool:
+def _match_name(name: str, target: str, fused: Optional[FusedMappping] = None) -> bool:
     """
-    Returns true if target string begins with "re:" and
-    regex matches or if target string exactly matches name
+    Returns true if target string begins with "re:" and regex matches or if target
+    string exactly matches name. If the name refers to a fused module defined by vLLM,
+    a `fused` mapping must be provided.
+
+    :param name: name of module
+    :param target: target name, potentially contains regex
+    :param fused: optional mapping from suffixes of fused modules to the suffixes
+        of their corresponding shards
     """
+    if fused is not None:
+        for fused_suffix in fused:
+            if name.endswith(fused_suffix):
+                name_stripped = name.removesuffix(fused_suffix)
+                return any(
+                    _match_name(name_stripped + shard_suffix, target)
+                    for shard_suffix in fused[fused_suffix]
+                )
+
     if target.startswith("re:"):
         return re.match(target.removeprefix("re:"), name) is not None
     else:
@@ -187,10 +231,20 @@ def _match_name(name: str, target: str) -> bool:
 
 def _match_class(module: torch.nn.Module, target: str) -> bool:
     """
-    Returns true if any torch parent class names match the target string exactly
+    Returns true if any torch parent class names match the target string exactly.
+    A special exception is made for vllm's `LinearBase` class which matches `Linear`
+
+    :param module: module to match
+    :param target: target which matches name or module
     """
     # will never match against a regex pattern since `:` is not allowed in class names
     return any(
-        issubclass(cls, torch.nn.Module) and cls.__name__ == target
+        (
+            issubclass(cls, torch.nn.Module)
+            and (
+                cls.__name__ == target
+                or (cls.__name__ == "LinearBase" and target == "Linear")
+            )
+        )
         for cls in module.__class__.__mro__
     )
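The fused-name resolution added to `_match_name` rewrites a fused module name (e.g. `...qkv_proj`) into each of its shard names and matches the target against those instead. A self-contained sketch of just that rewrite, mirroring the logic above outside the library:

```python
import re
from typing import Iterable, Mapping, Optional

def match_name(
    name: str, target: str, fused: Optional[Mapping[str, Iterable[str]]] = None
) -> bool:
    # resolve a fused suffix (e.g. "qkv_proj") into its shard names first
    if fused is not None:
        for fused_suffix in fused:
            if name.endswith(fused_suffix):
                stem = name.removesuffix(fused_suffix)
                return any(
                    match_name(stem + shard_suffix, target)
                    for shard_suffix in fused[fused_suffix]
                )
    if target.startswith("re:"):
        return re.match(target.removeprefix("re:"), name) is not None
    return target == name

mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
print(match_name("model.layers.0.self_attn.qkv_proj", "re:.*k_proj", mapping))  # True
print(match_name("model.layers.0.self_attn.qkv_proj", "re:.*o_proj", mapping))  # False
```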

tests/test_utils/test_match.py
Lines changed: 31 additions & 1 deletion

@@ -16,7 +16,6 @@
 
 import pytest
 import torch.nn as nn
-from accelerate import init_empty_weights
 
 # Assuming the module is named "module_matching" - adjust import as needed
 from compressed_tensors.utils import (
@@ -33,6 +32,11 @@ class DummyModel(nn.Module):
     """Test model for unit tests. Weights are initialized on meta device"""
 
     def __init__(self):
+        try:
+            from accelerate import init_empty_weights
+        except ImportError:
+            pytest.skip("Skipping: weight init requires accelerate")
+
         super().__init__()
         with init_empty_weights():
             self.layer1 = nn.Linear(10, 20)
@@ -142,6 +146,15 @@ def test_custom_module(self):
         assert _match_class(model, "DummyModel") == True
         assert _match_class(model, "Module") == True
 
+    def test_linear_base(self):
+        """Test matching against vllm's LinearBase class"""
+
+        class LinearBase(nn.Module):
+            pass
+
+        linear = LinearBase()
+        assert _match_class(linear, "Linear") == True
+
 
 class TestIsMatch:
     """Test cases for is_match function"""
@@ -180,6 +193,23 @@ class InternalLinear(InternalModule, nn.Linear):
         linear = InternalLinear(10, 20)
         assert is_match("layer1", linear, "re:layer.*") == False
 
+    def test_fused_mapping(self):
+        """Test matching through a fused (packed) module mapping"""
+        linear = nn.Linear(10, 20)
+        mapping = {
+            "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+            "gate_up_proj": ["gate_proj", "up_proj"],
+        }
+
+        assert is_match("dummy.qkv_proj", linear, "re:.*q_proj", mapping) == True
+        assert is_match("dummy.qkv_proj", linear, "re:.*k_proj", mapping) == True
+        assert is_match("dummy.qkv_proj", linear, "re:.*v_proj", mapping) == True
+        assert is_match("dummy.qkv_proj", linear, "Linear", mapping) == True
+
+        assert is_match("dummy.gate_up_proj", linear, "re:.*gate_proj", mapping) == True
+        assert is_match("dummy.gate_up_proj", linear, "re:.*up_proj", mapping) == True
+        assert is_match("dummy.gate_up_proj", linear, "Linear", mapping) == True
+
 
 class TestMatchNamedModules:
     """Test cases for match_named_modules function"""
