
Commit a55c1bc

[Transform] Revert deprecation of TransformScheme.head_dim for compatibility with vllm (#472)
* allow for use of head_dim for vllm
* cleanup
* revert TransformScheme.block_size
* test fixes
* docstring update

Signed-off-by: Brian Dellabetta <[email protected]>
1 parent 00c544f commit a55c1bc

File tree

5 files changed: +11 −47 lines

src/compressed_tensors/transform/factory/hadamard.py
src/compressed_tensors/transform/factory/matrix_multiply.py
src/compressed_tensors/transform/transform_scheme.py
tests/test_transform/factory/test_correctness.py
tests/test_transform/test_transform_scheme.py

src/compressed_tensors/transform/factory/hadamard.py (1 addition, 1 deletion)

@@ -52,7 +52,7 @@ def create_transform(self, module: Module, args: TransformArgs):
         :param args: defines how the transform will be applied to the module
         """
         assert hasattr(module, "weight")
-        size = get_transform_size(module, args.location, self.scheme.block_size)
+        size = get_transform_size(module, args.location, self.scheme.head_dim)
         exec_device = get_execution_device(module)
         device = get_offloaded_device(module)
         precision = self.scheme.precision if args.is_online() else torch.float64
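Both factories forward scheme.head_dim straight into get_transform_size. As a rough sketch of the sizing rule this implies (get_transform_size_sketch is a hypothetical stand-in written for this page, not the library's actual get_transform_size, which supports more module types and locations):

```python
from typing import Optional

from torch.nn import Linear, Module


def get_transform_size_sketch(
    module: Module, location: str, head_dim: Optional[int] = None
) -> int:
    """Sketch: pick the weight dimension the transform acts on, then let
    head_dim shrink it to a per-block size."""
    if not isinstance(module, Linear):
        raise NotImplementedError("sketch only covers torch.nn.Linear")
    # weight_input transforms act on in_features; weight_output on out_features
    size = module.in_features if location == "weight_input" else module.out_features
    if head_dim is not None:
        # head_dim makes the transform block diagonal: (size // head_dim)
        # square blocks, each head_dim wide
        assert size % head_dim == 0, "head_dim must evenly divide the weight dim"
        size = head_dim
    return size


module = Linear(4, 8, bias=False)
assert get_transform_size_sketch(module, "weight_input") == 4
assert get_transform_size_sketch(module, "weight_input", head_dim=2) == 2
```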

src/compressed_tensors/transform/factory/matrix_multiply.py (1 addition, 1 deletion)

@@ -51,7 +51,7 @@ def create_transform(self, module: Module, args: TransformArgs):
         :param args: defines how the transform will be applied to the module
         """
         assert hasattr(module, "weight")
-        size = get_transform_size(module, args.location, self.scheme.block_size)
+        size = get_transform_size(module, args.location, self.scheme.head_dim)
         device = get_offloaded_device(module)
         precision = self.scheme.precision if args.is_online() else torch.float64
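What a per-block size buys: multiplying by a block-diagonal transform is equivalent to reshaping the last dimension into head_dim-wide chunks and applying one small matrix to every chunk. A self-contained illustration (apply_block_diagonal is a hypothetical helper written for this example, not library code):

```python
import torch


def apply_block_diagonal(x: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
    """Apply the (head_dim x head_dim) matrix u block-diagonally
    along the last dimension of x."""
    head_dim = u.shape[0]
    *lead, dim = x.shape
    assert dim % head_dim == 0
    blocks = x.reshape(*lead, dim // head_dim, head_dim)
    return (blocks @ u).reshape(*lead, dim)


x = torch.randn(2, 8)
u = torch.eye(4)  # identity block; a real scheme would use e.g. a 4x4 Hadamard
assert torch.allclose(apply_block_diagonal(x, u), x)
```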
5757

src/compressed_tensors/transform/transform_scheme.py (7 additions, 18 deletions)

@@ -17,7 +17,7 @@
 import torch
 from compressed_tensors.transform import TransformArgs
 from compressed_tensors.utils import TorchDtype
-from pydantic import BaseModel, ConfigDict, Field, model_validator
+from pydantic import BaseModel, ConfigDict, Field


 __all__ = ["TransformScheme"]
@@ -36,8 +36,11 @@ class TransformScheme(BaseModel):
     :param randomize: True if uniquely randomized transform weights should be used,
         otherwise use identical transform weights where applicable
     :param requires_grad: True if weights include gradients for training
-    :param block_size: If set, the transform matrix will be block diagonal, with each
-        block being a square matrix of this size.
+    :param head_dim: If set, the transform matrix will be block diagonal with each
+        block being a square matrix of this size. The name head_dim was chosen because
+        some rotations need to be block-diagonal with block size equal to the head_dim,
+        but research has shown value in applying some rotations with smaller block size,
+        irrespective of head_dim.
     :param precision: Precision at which this transform should be applied during online
         rotations. Fused (offline) rotations are always performed in float64
     """
@@ -46,21 +49,7 @@ class TransformScheme(BaseModel):
     apply: List[TransformArgs] = Field(default_factory=list)
     randomize: bool = Field(default=False)
     requires_grad: bool = Field(default=False)
-    block_size: Optional[int] = Field(default=None)
-    head_dim: Optional[int] = Field(
-        default=None, deprecated="head_dim is deprecated, use block_size instead"
-    )
+    head_dim: Optional[int] = Field(default=None)
     precision: TorchDtype = Field(default=torch.float32)

-    @model_validator(mode="after")
-    def validate_model_after(model: "TransformScheme") -> "TransformScheme":
-        """
-        If head_dim is used instead of block_size, set block_size to head_dim
-        and remove head_dim
-        """
-        if model.block_size is None and model.head_dim is not None:
-            model.block_size = model.head_dim
-            model.head_dim = None
-        return model
-
     model_config = ConfigDict(extra="forbid")
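With the revert in place, head_dim is once again a plain optional field, block_size is gone, and because the model is configured with extra="forbid", a config that still passes block_size is rejected outright. A quick sanity check of the restored behavior:

```python
from compressed_tensors.transform import TransformScheme

scheme = TransformScheme(type="hadamard", head_dim=128)
assert scheme.head_dim == 128
# serializes under the head_dim key, which is the name vllm reads
assert scheme.model_dump()["head_dim"] == 128

# block_size is no longer a field, and extra="forbid" rejects unknown keys:
# TransformScheme(type="hadamard", block_size=128)  # raises pydantic.ValidationError
```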

tests/test_transform/factory/test_correctness.py (2 additions, 2 deletions)

@@ -33,7 +33,7 @@
 def test_correctness_linear(type, randomize, head_dim, input_batch_size):
     size = (4, 8)
     module = torch.nn.Linear(*size, bias=False)
-    scheme = TransformScheme(type=type, randomize=randomize, block_size=head_dim)
+    scheme = TransformScheme(type=type, randomize=randomize, head_dim=head_dim)
     factory = TransformFactory.from_scheme(scheme, name="")

     input_tfm = factory.create_transform(
@@ -150,7 +150,7 @@ def test_correctness_attention_heads(type, randomize, head_dim, input_batch_size
         "": TransformScheme(
             type=type,
             randomize=randomize,
-            block_size=head_dim,
+            head_dim=head_dim,
             apply=[
                 TransformArgs(targets="v_proj", location="weight_output"),
                 TransformArgs(
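For context on what these correctness tests exercise: a normalized Hadamard matrix H satisfies H @ H.T == I, so a rotation applied at weight_output can be undone by its transpose applied downstream, leaving the model's function unchanged. A minimal Sylvester construction (illustrative only; the library ships its own Hadamard utilities):

```python
import torch


def normalized_hadamard(n: int) -> torch.Tensor:
    """Sylvester construction of an n x n Hadamard matrix (n a power of
    two), scaled so that H @ H.T == I."""
    h = torch.ones(1, 1)
    while h.shape[0] < n:
        # H_{2k} = [[H, H], [H, -H]]
        h = torch.cat([torch.cat([h, h], dim=1), torch.cat([h, -h], dim=1)], dim=0)
    return h / (n ** 0.5)


H = normalized_hadamard(8)
assert torch.allclose(H @ H.T, torch.eye(8), atol=1e-6)
```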

tests/test_transform/test_transform_scheme.py (0 additions, 25 deletions)

@@ -72,28 +72,3 @@ def test_multiple_groups():
     assert not scheme.randomize
     assert scheme.type == "hadamard"
     assert len(scheme.apply) == 20
-
-
-def test_transform_scheme_block_size():
-    """
-    Ensure json with (deprecated) `head_dim` or `block_size`
-    both load up correctly and save with `block_size` field
-    """
-
-    old_scheme = TransformScheme.model_validate_json(
-        '{"type": "hadamard", "head_dim": 128}'
-    )
-    assert old_scheme.block_size == 128
-    assert old_scheme.model_dump()["block_size"] == 128
-    old_scheme = TransformScheme(type="hadamard", head_dim=64)
-    assert old_scheme.block_size == 64
-    assert old_scheme.model_dump()["block_size"] == 64
-
-    new_scheme = TransformScheme.model_validate_json(
-        '{"type": "hadamard", "block_size": 128}'
-    )
-    assert new_scheme.block_size == 128
-    assert new_scheme.model_dump()["block_size"] == 128
-    new_scheme = TransformScheme(type="hadamard", block_size=64)
-    assert new_scheme.block_size == 64
-    assert new_scheme.model_dump()["block_size"] == 64
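The deleted test pinned down the old aliasing (head_dim loads, block_size saves). After this commit the opposite holds: configs round-trip through the head_dim name itself, which is what vllm looks for. A short sketch of that round-trip, assuming the standard pydantic v2 JSON methods:

```python
import json

from compressed_tensors.transform import TransformScheme

scheme = TransformScheme.model_validate_json('{"type": "hadamard", "head_dim": 64}')
assert scheme.head_dim == 64
# the serialized form keeps the head_dim key rather than renaming it
assert json.loads(scheme.model_dump_json())["head_dim"] == 64
```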
