Skip to content

Commit bf74240

Browse files
committed
Merge remote-tracking branch 'origin' into kylesayrs/forbid-extra
2 parents 331cb3f + 131673e commit bf74240

File tree

9 files changed

+220
-27
lines changed

9 files changed

+220
-27
lines changed

pyproject.toml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,12 @@ build-backend = "setuptools.build_meta"
55
[tool.black]
66
line-length = 88
77
target-version = ['py36']
8+
9+
[tool.pytest.ini_options]
10+
markers = [
11+
"unit: tests to ensure code correctness and regression test functionality",
12+
"smoke: quick tests to check basic functionality",
13+
"sanity: tests to ensure that new changes do not break existing functionality",
14+
"regression: detailed tests to ensure major functions work correctly",
15+
"integration: tests which integrate with a third party service such as HF",
16+
]

src/compressed_tensors/transform/factory/hadamard.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import math
16-
from typing import Optional, Union
15+
from typing import Optional
1716

1817
import torch
1918
from compressed_tensors.transform import TransformArgs, TransformScheme
@@ -26,7 +25,7 @@
2625
from compressed_tensors.utils import get_execution_device, get_offloaded_device
2726
from compressed_tensors.utils.helpers import ParameterizedDefaultDict
2827
from torch import Tensor, device, dtype
29-
from torch.nn import Linear, Module, Parameter
28+
from torch.nn import Module, Parameter
3029

3130

3231
@TransformFactory.register("hadamard")
@@ -54,14 +53,14 @@ def create_transform(self, module: Module, args: TransformArgs):
5453
"""
5554
assert hasattr(module, "weight")
5655
size = get_transform_size(module, args.location, self.scheme.head_dim)
57-
dtype = module.weight.dtype
56+
dtype = self.scheme.precision
5857
device = get_offloaded_device(module)
5958
exec_device = get_execution_device(module)
6059

6160
factory_kwargs = {"construct_device": exec_device}
6261
weight = self.weights.get(size, dtype, device, factory_kwargs=factory_kwargs)
6362
perm = self.perms[weight] if self.scheme.randomize else None
64-
return HadamardTransform(weight, perm, args, type(module))
63+
return HadamardTransform(weight, perm, self.scheme, args, type(module))
6564

6665
def _create_weight(
6766
self,
@@ -85,15 +84,18 @@ def __init__(
8584
self,
8685
weight: Parameter,
8786
perm: Optional[Parameter],
87+
scheme: TransformScheme,
8888
args: TransformArgs,
8989
module_type: type[torch.nn.Module],
9090
):
9191
super().__init__()
9292
self.weight = weight
9393
self.perm = perm
94+
self.scheme = scheme
9495
self.args = args
9596
self.module_type = module_type
96-
self._scale = math.sqrt(weight.size(0))
97+
self._scale = torch.tensor(weight.size(0), dtype=self.scheme.precision).sqrt()
98+
self._precision = scheme.precision if args.is_online() else torch.float64
9799

98100
def forward(self, value: Tensor) -> Tensor:
99101
weight = self.weight
@@ -105,6 +107,11 @@ def forward(self, value: Tensor) -> Tensor:
105107
weight = weight.T
106108

107109
return (
108-
apply_transform_weight(weight, value, self.args.location, self.module_type)
110+
apply_transform_weight(
111+
weight.to(self._precision),
112+
value.to(self._precision),
113+
self.args.location,
114+
self.module_type,
115+
)
109116
/ self._scale
110-
)
117+
).to(value.dtype)

src/compressed_tensors/transform/factory/matrix_multiply.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from compressed_tensors.utils import get_offloaded_device
2525
from compressed_tensors.utils.helpers import ParameterizedDefaultDict
2626
from torch import Tensor, device, dtype
27-
from torch.nn import Linear, Module, Parameter
27+
from torch.nn import Module, Parameter
2828

2929

3030
@TransformFactory.register("random-matrix")
@@ -52,14 +52,14 @@ def create_transform(self, module: Module, args: TransformArgs):
5252
"""
5353
assert hasattr(module, "weight")
5454
size = get_transform_size(module, args.location, self.scheme.head_dim)
55-
dtype = module.weight.dtype
55+
dtype = self.scheme.precision
5656
device = get_offloaded_device(module)
5757

5858
weight = self.weights[size, dtype, device]
5959
if args.inverse:
6060
weight = self.inverses[weight]
6161

62-
return RandomMatrixTransform(weight, args, type(module))
62+
return RandomMatrixTransform(weight, self.scheme, args, type(module))
6363

6464
def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
6565
# TODO: verify that weight is invertible (has non-zero determinant)
@@ -78,25 +78,34 @@ class RandomMatrixTransform(TransformBase):
7878
def __init__(
7979
self,
8080
weight: Tensor,
81+
scheme: TransformScheme,
8182
args: TransformArgs,
8283
module_type: type[torch.nn.Module],
8384
):
8485
super().__init__()
8586
self.weight = weight # is an inverse if args.inverse
87+
self.scheme = scheme
8688
self.args = args
8789
self.module_type = module_type
90+
self._precision = scheme.precision if args.is_online() else torch.float64
8891

8992
def forward(self, value: Tensor) -> Parameter:
9093
return apply_transform_weight(
91-
self.weight, value, self.args.location, self.module_type
92-
)
94+
self.weight.to(self._precision),
95+
value.to(self._precision),
96+
self.args.location,
97+
self.module_type,
98+
).to(value.dtype)
9399

94100
def right_inverse(self, value: Tensor) -> Tensor:
95101
inverse = high_precision_invert(self.weight)
96102
return apply_transform_weight(
97-
inverse, value, self.args.location, self.module_type
98-
)
103+
inverse.to(self._precision),
104+
value.to(self._precision),
105+
self.args.location,
106+
self.module_type,
107+
).to(value.dtype)
99108

100109

101110
def high_precision_invert(weight: Tensor) -> Tensor:
102-
return torch.linalg.inv(weight.to(torch.float32)).to(weight.dtype)
111+
return torch.linalg.inv(weight.to(torch.float64)).to(weight.dtype)

src/compressed_tensors/transform/transform_args.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,4 +69,10 @@ def wrap_singleton(cls, value):
6969
return [value]
7070
return value
7171

72+
def is_online(self) -> bool:
73+
return self.location not in (
74+
TransformLocation.WEIGHT_INPUT,
75+
TransformLocation.WEIGHT_OUTPUT,
76+
)
77+
7278
model_config = ConfigDict(extra="forbid")

src/compressed_tensors/transform/transform_scheme.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414

1515
from typing import List, Optional
1616

17+
import torch
1718
from compressed_tensors.transform import TransformArgs
19+
from compressed_tensors.utils import TorchDtype
1820
from pydantic import BaseModel, ConfigDict, Field
1921

2022

@@ -34,12 +36,15 @@ class TransformScheme(BaseModel):
3436
:param randomize: True if uniquely randomized transform weights should be used,
3537
otherwise use identical transform weights where applicable
3638
:param requires_grad: True if weights include gradients for training
39+
:param precision: Precision at which this transform should be applied during online
40+
rotations. Fused (offline) rotations are always performed in float64
3741
"""
3842

3943
type: str
4044
apply: List[TransformArgs] = Field(default_factory=list)
4145
randomize: bool = Field(default=False)
4246
requires_grad: bool = Field(default=False)
4347
head_dim: Optional[int] = Field(default=None)
48+
precision: TorchDtype = Field(default=torch.float32)
4449

4550
model_config = ConfigDict(extra="forbid")

src/compressed_tensors/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,4 @@
2121
from .permute import *
2222
from .safetensors_load import *
2323
from .semi_structured_conversions import *
24+
from .type import *

src/compressed_tensors/utils/type.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Annotated, Any
16+
17+
import torch
18+
from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler
19+
from pydantic.json_schema import JsonSchemaValue
20+
from pydantic_core import core_schema
21+
22+
23+
__all__ = ["TorchDtype"]
24+
25+
26+
class _TorchDtypeAnnotation:
27+
@classmethod
28+
def __get_pydantic_core_schema__(
29+
cls,
30+
_source_type: Any,
31+
_handler: GetCoreSchemaHandler,
32+
) -> core_schema.CoreSchema:
33+
# support strings of the form `torch.xxx` or `xxx`
34+
def validate_from_str(name: str) -> torch.dtype:
35+
name = name.removeprefix("torch.")
36+
try:
37+
value = getattr(torch, name)
38+
assert isinstance(value, torch.dtype)
39+
except Exception:
40+
raise ValueError(f"No such torch dtype `torch.{name}`")
41+
42+
return value
43+
44+
# package validation into a schema (which also validates str type)
45+
from_str_schema = core_schema.chain_schema(
46+
[
47+
core_schema.str_schema(),
48+
core_schema.no_info_plain_validator_function(validate_from_str),
49+
]
50+
)
51+
52+
return core_schema.json_or_python_schema(
53+
json_schema=from_str_schema,
54+
python_schema=core_schema.union_schema(
55+
[
56+
# support both torch.dtype or strings
57+
core_schema.is_instance_schema(torch.dtype),
58+
from_str_schema,
59+
]
60+
),
61+
# serialize as `torch.xxx`
62+
serialization=core_schema.plain_serializer_function_ser_schema(
63+
lambda instance: str(instance)
64+
),
65+
)
66+
67+
@classmethod
68+
def __get_pydantic_json_schema__(
69+
cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
70+
) -> JsonSchemaValue:
71+
return handler(core_schema.str_schema())
72+
73+
74+
TorchDtype = Annotated[torch.dtype, _TorchDtypeAnnotation]

tests/test_transform/test_transform_config.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
from compressed_tensors.transform import TransformArgs, TransformConfig, TransformScheme
1818

1919

20-
@pytest.fixture
21-
def basic_transform_scheme():
20+
@pytest.fixture(scope="module")
21+
def scheme():
2222
targets = ["Embedding"]
2323
location = "input"
2424
basic_args = TransformArgs(targets=targets, location=location)
@@ -29,21 +29,20 @@ def basic_transform_scheme():
2929
)
3030

3131

32-
def test_basic(basic_transform_scheme):
33-
config = TransformConfig(
32+
@pytest.fixture(scope="module")
33+
def config(scheme):
34+
return TransformConfig(
3435
config_groups={
35-
"transform_0": basic_transform_scheme,
36+
"transform_0": scheme,
3637
}
3738
)
39+
40+
41+
def test_basic(config):
3842
assert isinstance(config.config_groups.get("transform_0"), TransformScheme)
3943

4044

41-
def test_to_dict(basic_transform_scheme):
42-
config = TransformConfig(
43-
config_groups={
44-
"transform_0": basic_transform_scheme,
45-
}
46-
)
45+
def test_to_dict(config):
4746
config_dict = config.model_dump()
4847
assert "config_groups" in config_dict.keys()
4948

@@ -69,3 +68,7 @@ def test_multiple_groups():
6968
config = TransformConfig(
7069
config_groups={"transform_0": scheme_1, "transform_1": scheme_2}
7170
)
71+
72+
73+
def test_reload(config):
74+
assert config == TransformConfig.model_validate(config.model_dump())

tests/test_utils/test_type.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pytest
16+
import torch
17+
from compressed_tensors.utils.type import TorchDtype
18+
from pydantic import BaseModel, Field
19+
from pydantic_core._pydantic_core import ValidationError
20+
21+
22+
class DummyModel(BaseModel):
23+
dtype: TorchDtype = Field(default=torch.float32)
24+
25+
26+
@pytest.mark.unit
27+
def test_default_value():
28+
model = DummyModel()
29+
assert model.dtype == torch.float32
30+
31+
32+
@pytest.mark.unit
33+
def test_value_override():
34+
model = DummyModel()
35+
model.dtype = torch.float16
36+
assert model.dtype == torch.float16
37+
38+
39+
@pytest.mark.unit
40+
def test_validation():
41+
DummyModel(dtype=torch.float16)
42+
DummyModel(dtype="torch.float16")
43+
DummyModel(dtype="float16")
44+
45+
with pytest.raises(ValidationError):
46+
model = DummyModel(dtype="notatype")
47+
48+
49+
@pytest.mark.unit
50+
def test_serialization():
51+
model = DummyModel()
52+
assert model.model_dump()["dtype"] == "torch.float32"
53+
assert DummyModel.model_validate(model.model_dump()) == model
54+
55+
model = DummyModel(dtype=torch.float16)
56+
assert model.model_dump()["dtype"] == "torch.float16"
57+
assert DummyModel.model_validate(model.model_dump()) == model
58+
59+
model = DummyModel()
60+
model.dtype = torch.float16
61+
assert model.model_dump()["dtype"] == "torch.float16"
62+
assert DummyModel.model_validate(model.model_dump()) == model
63+
64+
65+
@pytest.mark.unit
66+
def test_deserialization():
67+
dummy_dict = {"dtype": "torch.float16"}
68+
assert DummyModel.model_validate(dummy_dict).dtype == torch.float16
69+
70+
dummy_dict = {"dtype": "float16"}
71+
assert DummyModel.model_validate(dummy_dict).dtype == torch.float16
72+
73+
with pytest.raises(ValueError):
74+
dummy_dict = {"dtype": "notatype"}
75+
DummyModel.model_validate(dummy_dict)
76+
77+
with pytest.raises(ValueError):
78+
dummy_dict = {"dtype": "torch.notatype"}
79+
DummyModel.model_validate(dummy_dict)

0 commit comments

Comments
 (0)