
Commit 0e5df88

TransformScheme.block_size, deprecate head_dim (#466)
Signed-off-by: Brian Dellabetta <[email protected]>
1 parent de9cfc6 commit 0e5df88

File tree

5 files changed: +47 -6 lines changed


src/compressed_tensors/transform/factory/hadamard.py

Lines changed: 1 addition & 1 deletion
@@ -52,7 +52,7 @@ def create_transform(self, module: Module, args: TransformArgs):
         :param args: defines how the transform will be applied to the module
         """
         assert hasattr(module, "weight")
-        size = get_transform_size(module, args.location, self.scheme.head_dim)
+        size = get_transform_size(module, args.location, self.scheme.block_size)
         exec_device = get_execution_device(module)
         device = get_offloaded_device(module)
         precision = self.scheme.precision if args.is_online() else torch.float64

src/compressed_tensors/transform/factory/matrix_multiply.py

Lines changed: 1 addition & 1 deletion
@@ -51,7 +51,7 @@ def create_transform(self, module: Module, args: TransformArgs):
         :param args: defines how the transform will be applied to the module
         """
         assert hasattr(module, "weight")
-        size = get_transform_size(module, args.location, self.scheme.head_dim)
+        size = get_transform_size(module, args.location, self.scheme.block_size)
         device = get_offloaded_device(module)
         precision = self.scheme.precision if args.is_online() else torch.float64
 
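Both factories now pass scheme.block_size (rather than the deprecated head_dim) into get_transform_size. For reference, here is a minimal sketch of the block-diagonal structure that block_size describes (see the updated docstring in transform_scheme.py below). It is illustrative only: it assumes a power-of-two block size and an unnormalized Sylvester Hadamard block, whereas the actual factories construct and normalize their transform weights themselves.

import torch


def block_diagonal_hadamard(size: int, block_size: int) -> torch.Tensor:
    # Illustrative sketch only: shows the block-diagonal shape that
    # `block_size` describes, not the library's actual weight construction.
    assert size % block_size == 0, "block_size must evenly divide the transform size"

    # Sylvester construction of a single (unnormalized) Hadamard block,
    # assuming block_size is a power of two.
    block = torch.ones(1, 1)
    while block.shape[0] < block_size:
        block = torch.cat(
            [torch.cat([block, block], dim=1), torch.cat([block, -block], dim=1)],
            dim=0,
        )

    # Tile the block along the diagonal to cover the full transform dimension.
    return torch.block_diag(*[block] * (size // block_size))


# e.g. an 8x8 transform made of two 4x4 Hadamard blocks
print(block_diagonal_hadamard(8, 4))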

src/compressed_tensors/transform/transform_scheme.py

Lines changed: 18 additions & 2 deletions
@@ -17,7 +17,7 @@
 import torch
 from compressed_tensors.transform import TransformArgs
 from compressed_tensors.utils import TorchDtype
-from pydantic import BaseModel, ConfigDict, Field
+from pydantic import BaseModel, ConfigDict, Field, model_validator
 
 
 __all__ = ["TransformScheme"]
@@ -36,6 +36,8 @@ class TransformScheme(BaseModel):
     :param randomize: True if uniquely randomized transform weights should be used,
         otherwise use identical transform weights where applicable
     :param requires_grad: True if weights include gradients for training
+    :param block_size: If set, the transform matrix will be block diagonal, with each
+        block being a square matrix of this size.
     :param precision: Precision at which this transform should be applied during online
         rotations. Fused (offline) rotations are always performed in float64
     """
@@ -44,7 +46,21 @@ class TransformScheme(BaseModel):
     apply: List[TransformArgs] = Field(default_factory=list)
     randomize: bool = Field(default=False)
     requires_grad: bool = Field(default=False)
-    head_dim: Optional[int] = Field(default=None)
+    block_size: Optional[int] = Field(default=None)
+    head_dim: Optional[int] = Field(
+        default=None, deprecated="head_dim is deprecated, use block_size instead"
+    )
     precision: TorchDtype = Field(default=torch.float32)
 
+    @model_validator(mode="after")
+    def validate_model_after(model: "TransformScheme") -> "TransformScheme":
+        """
+        If head_dim is used instead of block_size, set block_size to head_dim
+        and remove head_dim
+        """
+        if model.block_size is None and model.head_dim is not None:
+            model.block_size = model.head_dim
+            model.head_dim = None
+        return model
+
     model_config = ConfigDict(extra="forbid")
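For reference, a minimal usage sketch of the migration path this validator introduces: configs that still pass the deprecated head_dim keyword validate as before, but the value is moved into block_size. This mirrors the tests added below and only assumes the public TransformScheme constructor.

from compressed_tensors.transform import TransformScheme

# Deprecated spelling: the "after" validator copies head_dim into block_size
# and clears head_dim, so downstream code only ever reads block_size.
old = TransformScheme(type="hadamard", head_dim=128)
assert old.block_size == 128 and old.head_dim is None

# Preferred spelling going forward.
new = TransformScheme(type="hadamard", block_size=128)
assert new.block_size == 128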

tests/test_transform/factory/test_correctness.py

Lines changed: 2 additions & 2 deletions
@@ -33,7 +33,7 @@
 def test_correctness_linear(type, randomize, head_dim, input_batch_size):
     size = (4, 8)
     module = torch.nn.Linear(*size, bias=False)
-    scheme = TransformScheme(type=type, randomize=randomize, head_dim=head_dim)
+    scheme = TransformScheme(type=type, randomize=randomize, block_size=head_dim)
     factory = TransformFactory.from_scheme(scheme, name="")
 
     input_tfm = factory.create_transform(
@@ -150,7 +150,7 @@ def test_correctness_attention_heads(type, randomize, head_dim, input_batch_size
            "": TransformScheme(
                type=type,
                randomize=randomize,
-               head_dim=head_dim,
+               block_size=head_dim,
                apply=[
                    TransformArgs(targets="v_proj", location="weight_output"),
                    TransformArgs(

tests/test_transform/test_transform_scheme.py

Lines changed: 25 additions & 0 deletions
@@ -72,3 +72,28 @@ def test_multiple_groups():
     assert not scheme.randomize
     assert scheme.type == "hadamard"
     assert len(scheme.apply) == 20
+
+
+def test_transform_scheme_block_size():
+    """
+    Ensure json with (deprecated) `head_dim` or `block_size`
+    both load up correctly and save with `block_size` field
+    """
+
+    old_scheme = TransformScheme.model_validate_json(
+        '{"type": "hadamard", "head_dim": 128}'
+    )
+    assert old_scheme.block_size == 128
+    assert old_scheme.model_dump()["block_size"] == 128
+    old_scheme = TransformScheme(type="hadamard", head_dim=64)
+    assert old_scheme.block_size == 64
+    assert old_scheme.model_dump()["block_size"] == 64
+
+    new_scheme = TransformScheme.model_validate_json(
+        '{"type": "hadamard", "block_size": 128}'
+    )
+    assert new_scheme.block_size == 128
+    assert new_scheme.model_dump()["block_size"] == 128
+    new_scheme = TransformScheme(type="hadamard", block_size=64)
+    assert new_scheme.block_size == 64
+    assert new_scheme.model_dump()["block_size"] == 64
