From 90ea08ff1cd790fc91ca731ca46d54522a0bbdd5 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 5 Aug 2025 11:56:36 -0400 Subject: [PATCH 1/5] support precision and torch dtype Signed-off-by: Kyle Sayers --- pyproject.toml | 9 +++ .../transform/factory/hadamard.py | 21 +++-- .../transform/factory/matrix_multiply.py | 21 +++-- .../transform/transform_scheme.py | 3 + src/compressed_tensors/utils/__init__.py | 1 + src/compressed_tensors/utils/type.py | 74 +++++++++++++++++ tests/test_transform/test_transform_config.py | 25 +++--- tests/test_utils/test_type.py | 79 +++++++++++++++++++ 8 files changed, 209 insertions(+), 24 deletions(-) create mode 100644 src/compressed_tensors/utils/type.py create mode 100644 tests/test_utils/test_type.py diff --git a/pyproject.toml b/pyproject.toml index b0f019be..1238a376 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,3 +5,12 @@ build-backend = "setuptools.build_meta" [tool.black] line-length = 88 target-version = ['py36'] + +[tool.pytest.ini_options] +markers = [ + "unit: tests to ensure code correctness and regression test functionality", + "smoke: quick tests to check basic functionality", + "sanity: tests to ensure that new changes do not break existing functionality", + "regression: detailed tests to ensure major functions work correctly", + "integration: tests which integrate with a third party service such as HF", +] \ No newline at end of file diff --git a/src/compressed_tensors/transform/factory/hadamard.py b/src/compressed_tensors/transform/factory/hadamard.py index 02ebd89b..4a27c8c6 100644 --- a/src/compressed_tensors/transform/factory/hadamard.py +++ b/src/compressed_tensors/transform/factory/hadamard.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import math -from typing import Optional, Union +from typing import Optional import torch from compressed_tensors.transform import TransformArgs, TransformScheme @@ -26,7 +25,7 @@ from compressed_tensors.utils import get_execution_device, get_offloaded_device from compressed_tensors.utils.helpers import ParameterizedDefaultDict from torch import Tensor, device, dtype -from torch.nn import Linear, Module, Parameter +from torch.nn import Module, Parameter @TransformFactory.register("hadamard") @@ -57,11 +56,12 @@ def create_transform(self, module: Module, args: TransformArgs): dtype = module.weight.dtype device = get_offloaded_device(module) exec_device = get_execution_device(module) + precision = self.scheme.precision factory_kwargs = {"construct_device": exec_device} weight = self.weights.get(size, dtype, device, factory_kwargs=factory_kwargs) perm = self.perms[weight] if self.scheme.randomize else None - return HadamardTransform(weight, perm, args, type(module)) + return HadamardTransform(weight, perm, args, precision, type(module)) def _create_weight( self, @@ -86,14 +86,16 @@ def __init__( weight: Parameter, perm: Optional[Parameter], args: TransformArgs, + precision: torch.dtype, module_type: type[torch.nn.Module], ): super().__init__() self.weight = weight self.perm = perm self.args = args + self.precision = precision self.module_type = module_type - self._scale = math.sqrt(weight.size(0)) + self._scale = torch.tensor(weight.size(0), dtype=self.precision).sqrt() def forward(self, value: Tensor) -> Tensor: weight = self.weight @@ -105,6 +107,11 @@ def forward(self, value: Tensor) -> Tensor: weight = weight.T return ( - apply_transform_weight(weight, value, self.args.location, self.module_type) + apply_transform_weight( + weight.to(self.precision), + value.to(self.precision), + self.args.location, + self.module_type, + ) / self._scale - ) + ).to(value.dtype) diff --git a/src/compressed_tensors/transform/factory/matrix_multiply.py b/src/compressed_tensors/transform/factory/matrix_multiply.py index 8b829451..057b14d6 100644 --- a/src/compressed_tensors/transform/factory/matrix_multiply.py +++ b/src/compressed_tensors/transform/factory/matrix_multiply.py @@ -54,12 +54,13 @@ def create_transform(self, module: Module, args: TransformArgs): size = get_transform_size(module, args.location, self.scheme.head_dim) dtype = module.weight.dtype device = get_offloaded_device(module) + precision = self.scheme.precision weight = self.weights[size, dtype, device] if args.inverse: weight = self.inverses[weight] - return RandomMatrixTransform(weight, args, type(module)) + return RandomMatrixTransform(weight, args, precision, type(module)) def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter: # TODO: verify that weight is invertible (has non-zero determinant) @@ -78,24 +79,32 @@ def __init__( self, weight: Tensor, args: TransformArgs, + precision: torch.dtype, module_type: type[torch.nn.Module], ): super().__init__() self.weight = weight # is an inverse if args.inverse self.args = args + self.precision = precision self.module_type = module_type def forward(self, value: Tensor) -> Parameter: return apply_transform_weight( - self.weight, value, self.args.location, self.module_type - ) + self.weight.to(self.precision), + value.to(self.precision), + self.args.location, + self.module_type, + ).to(value.dtype) def right_inverse(self, value: Tensor) -> Tensor: inverse = high_precision_invert(self.weight) return apply_transform_weight( - inverse, value, self.args.location, self.module_type - ) + inverse.to(self.precision), + value.to(self.precision), + self.args.location, + self.module_type, + ).to(value.dtype) def high_precision_invert(weight: Tensor) -> Tensor: - return torch.linalg.inv(weight.to(torch.float32)).to(weight.dtype) + return torch.linalg.inv(weight.to(torch.float64)).to(weight.dtype) diff --git a/src/compressed_tensors/transform/transform_scheme.py b/src/compressed_tensors/transform/transform_scheme.py index 1620c541..0094b819 100644 --- a/src/compressed_tensors/transform/transform_scheme.py +++ b/src/compressed_tensors/transform/transform_scheme.py @@ -14,7 +14,9 @@ from typing import List, Optional +import torch from compressed_tensors.transform import TransformArgs +from compressed_tensors.utils import TorchDtype from pydantic import BaseModel, Field @@ -41,3 +43,4 @@ class TransformScheme(BaseModel): randomize: bool = Field(default=False) requires_grad: bool = Field(default=False) head_dim: Optional[int] = Field(default=None) + precision: TorchDtype = Field(default=torch.bfloat16) diff --git a/src/compressed_tensors/utils/__init__.py b/src/compressed_tensors/utils/__init__.py index c5f401d1..8763e6ee 100644 --- a/src/compressed_tensors/utils/__init__.py +++ b/src/compressed_tensors/utils/__init__.py @@ -21,3 +21,4 @@ from .permute import * from .safetensors_load import * from .semi_structured_conversions import * +from .type import * diff --git a/src/compressed_tensors/utils/type.py b/src/compressed_tensors/utils/type.py new file mode 100644 index 00000000..90453721 --- /dev/null +++ b/src/compressed_tensors/utils/type.py @@ -0,0 +1,74 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Annotated, Any + +import torch +from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler +from pydantic.json_schema import JsonSchemaValue +from pydantic_core import core_schema + + +__all__ = ["TorchDtype"] + + +class _TorchDtypeAnnotation: + @classmethod + def __get_pydantic_core_schema__( + cls, + _source_type: Any, + _handler: GetCoreSchemaHandler, + ) -> core_schema.CoreSchema: + # support strings of the form `torch.xxx` or `xxx` + def validate_from_str(name: str) -> torch.dtype: + name = name.removeprefix("torch.") + try: + value = getattr(torch, name) + assert isinstance(value, torch.dtype) + except AttributeError: + raise ValueError(f"No such torch dtype `torch.{name}`") + + return value + + # package validation into a schema (which also validates str type) + from_str_schema = core_schema.chain_schema( + [ + core_schema.str_schema(), + core_schema.no_info_plain_validator_function(validate_from_str), + ] + ) + + return core_schema.json_or_python_schema( + json_schema=from_str_schema, + python_schema=core_schema.union_schema( + [ + # support both torch.dtype or strings + core_schema.is_instance_schema(torch.dtype), + from_str_schema, + ] + ), + # serialize as `torch.xxx` + serialization=core_schema.plain_serializer_function_ser_schema( + lambda instance: str(instance) + ), + ) + + @classmethod + def __get_pydantic_json_schema__( + cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler + ) -> JsonSchemaValue: + return handler(core_schema.str_schema()) + + +TorchDtype = Annotated[torch.dtype, _TorchDtypeAnnotation] diff --git a/tests/test_transform/test_transform_config.py b/tests/test_transform/test_transform_config.py index 52167cfd..ad8e6645 100644 --- a/tests/test_transform/test_transform_config.py +++ b/tests/test_transform/test_transform_config.py @@ -17,8 +17,8 @@ from compressed_tensors.transform import TransformArgs, TransformConfig, TransformScheme -@pytest.fixture -def basic_transform_scheme(): +@pytest.fixture(scope="module") +def scheme(): targets = ["Embedding"] location = "input" basic_args = TransformArgs(targets=targets, location=location) @@ -29,21 +29,20 @@ def basic_transform_scheme(): ) -def test_basic(basic_transform_scheme): - config = TransformConfig( +@pytest.fixture(scope="module") +def config(scheme): + return TransformConfig( config_groups={ - "transform_0": basic_transform_scheme, + "transform_0": scheme, } ) + + +def test_basic(config): assert isinstance(config.config_groups.get("transform_0"), TransformScheme) -def test_to_dict(basic_transform_scheme): - config = TransformConfig( - config_groups={ - "transform_0": basic_transform_scheme, - } - ) +def test_to_dict(config): config_dict = config.model_dump() assert "config_groups" in config_dict.keys() @@ -69,3 +68,7 @@ def test_multiple_groups(): config = TransformConfig( config_groups={"transform_0": scheme_1, "transform_1": scheme_2} ) + + +def test_reload(config): + assert config == TransformConfig.model_validate(config.model_dump()) diff --git a/tests/test_utils/test_type.py b/tests/test_utils/test_type.py new file mode 100644 index 00000000..4eb38091 --- /dev/null +++ b/tests/test_utils/test_type.py @@ -0,0 +1,79 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch +from compressed_tensors.utils.type import TorchDtype +from pydantic import BaseModel, Field +from pydantic_core._pydantic_core import ValidationError + + +class DummyModel(BaseModel): + dtype: TorchDtype = Field(default=torch.float32) + + +@pytest.mark.unit +def test_default_value(): + model = DummyModel() + assert model.dtype == torch.float32 + + +@pytest.mark.unit +def test_value_override(): + model = DummyModel() + model.dtype = torch.float16 + assert model.dtype == torch.float16 + + +@pytest.mark.unit +def test_validation(): + DummyModel(dtype=torch.float16) + DummyModel(dtype="torch.float16") + DummyModel(dtype="float16") + + with pytest.raises(ValidationError): + model = DummyModel(dtype="notatype") + + +@pytest.mark.unit +def test_serialization(): + model = DummyModel() + assert model.model_dump()["dtype"] == "torch.float32" + assert DummyModel.model_validate(model.model_dump()) == model + + model = DummyModel(dtype=torch.float16) + assert model.model_dump()["dtype"] == "torch.float16" + assert DummyModel.model_validate(model.model_dump()) == model + + model = DummyModel() + model.dtype = torch.float16 + assert model.model_dump()["dtype"] == "torch.float16" + assert DummyModel.model_validate(model.model_dump()) == model + + +@pytest.mark.unit +def test_deserialization(): + dummy_dict = {"dtype": "torch.float16"} + assert DummyModel.model_validate(dummy_dict).dtype == torch.float16 + + dummy_dict = {"dtype": "float16"} + assert DummyModel.model_validate(dummy_dict).dtype == torch.float16 + + with pytest.raises(ValueError): + dummy_dict = {"dtype": "notatype"} + DummyModel.model_validate(dummy_dict) + + with pytest.raises(ValueError): + dummy_dict = {"dtype": "torch.notatype"} + DummyModel.model_validate(dummy_dict) From 5db0e13095ea8f9723bd4892aeb94ff27a9d7306 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 5 Aug 2025 16:13:29 +0000 Subject: [PATCH 2/5] cleanup, construct on dtype, change default Signed-off-by: Kyle Sayers --- .../transform/factory/hadamard.py | 15 +++++++-------- .../transform/factory/matrix_multiply.py | 16 ++++++++-------- .../transform/transform_scheme.py | 4 +++- 3 files changed, 18 insertions(+), 17 deletions(-) diff --git a/src/compressed_tensors/transform/factory/hadamard.py b/src/compressed_tensors/transform/factory/hadamard.py index 4a27c8c6..c8784126 100644 --- a/src/compressed_tensors/transform/factory/hadamard.py +++ b/src/compressed_tensors/transform/factory/hadamard.py @@ -53,15 +53,14 @@ def create_transform(self, module: Module, args: TransformArgs): """ assert hasattr(module, "weight") size = get_transform_size(module, args.location, self.scheme.head_dim) - dtype = module.weight.dtype + dtype = self.scheme.precision device = get_offloaded_device(module) exec_device = get_execution_device(module) - precision = self.scheme.precision factory_kwargs = {"construct_device": exec_device} weight = self.weights.get(size, dtype, device, factory_kwargs=factory_kwargs) perm = self.perms[weight] if self.scheme.randomize else None - return HadamardTransform(weight, perm, args, precision, type(module)) + return HadamardTransform(weight, perm, self.scheme, args, type(module)) def _create_weight( self, @@ -85,17 +84,17 @@ def __init__( self, weight: Parameter, perm: Optional[Parameter], + scheme: TransformScheme, args: TransformArgs, - precision: torch.dtype, module_type: type[torch.nn.Module], ): super().__init__() self.weight = weight self.perm = perm + self.scheme = scheme self.args = args - self.precision = precision self.module_type = module_type - self._scale = torch.tensor(weight.size(0), dtype=self.precision).sqrt() + self._scale = torch.tensor(weight.size(0), dtype=self.scheme.precision).sqrt() def forward(self, value: Tensor) -> Tensor: weight = self.weight @@ -108,8 +107,8 @@ def forward(self, value: Tensor) -> Tensor: return ( apply_transform_weight( - weight.to(self.precision), - value.to(self.precision), + weight.to(self.scheme.precision), + value.to(self.scheme.precision), self.args.location, self.module_type, ) diff --git a/src/compressed_tensors/transform/factory/matrix_multiply.py b/src/compressed_tensors/transform/factory/matrix_multiply.py index 057b14d6..d8537485 100644 --- a/src/compressed_tensors/transform/factory/matrix_multiply.py +++ b/src/compressed_tensors/transform/factory/matrix_multiply.py @@ -24,7 +24,7 @@ from compressed_tensors.utils import get_offloaded_device from compressed_tensors.utils.helpers import ParameterizedDefaultDict from torch import Tensor, device, dtype -from torch.nn import Linear, Module, Parameter +from torch.nn import Module, Parameter @TransformFactory.register("random-matrix") @@ -52,7 +52,7 @@ def create_transform(self, module: Module, args: TransformArgs): """ assert hasattr(module, "weight") size = get_transform_size(module, args.location, self.scheme.head_dim) - dtype = module.weight.dtype + dtype = self.scheme.precision device = get_offloaded_device(module) precision = self.scheme.precision @@ -78,20 +78,20 @@ class RandomMatrixTransform(TransformBase): def __init__( self, weight: Tensor, + scheme: TransformScheme, args: TransformArgs, - precision: torch.dtype, module_type: type[torch.nn.Module], ): super().__init__() self.weight = weight # is an inverse if args.inverse + self.scheme = scheme self.args = args - self.precision = precision self.module_type = module_type def forward(self, value: Tensor) -> Parameter: return apply_transform_weight( - self.weight.to(self.precision), - value.to(self.precision), + self.weight.to(self.scheme.precision), + value.to(self.scheme.precision), self.args.location, self.module_type, ).to(value.dtype) @@ -99,8 +99,8 @@ def forward(self, value: Tensor) -> Parameter: def right_inverse(self, value: Tensor) -> Tensor: inverse = high_precision_invert(self.weight) return apply_transform_weight( - inverse.to(self.precision), - value.to(self.precision), + inverse.to(self.scheme.precision), + value.to(self.scheme.precision), self.args.location, self.module_type, ).to(value.dtype) diff --git a/src/compressed_tensors/transform/transform_scheme.py b/src/compressed_tensors/transform/transform_scheme.py index 0094b819..f8d26e76 100644 --- a/src/compressed_tensors/transform/transform_scheme.py +++ b/src/compressed_tensors/transform/transform_scheme.py @@ -36,6 +36,8 @@ class TransformScheme(BaseModel): :param randomize: True if uniquely randomized transform weights should be used, otherwise use identical transform weights where applicable :param requires_grad: True if weights include gradients for training + :param precision: Precision at which this transform should be applied. This applies + to both weight fusing and online rotations """ type: str @@ -43,4 +45,4 @@ class TransformScheme(BaseModel): randomize: bool = Field(default=False) requires_grad: bool = Field(default=False) head_dim: Optional[int] = Field(default=None) - precision: TorchDtype = Field(default=torch.bfloat16) + precision: TorchDtype = Field(default=torch.float32) From 08f58ded46e087b00bb069b5cd1005b6b439c18a Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 6 Aug 2025 01:19:50 -0400 Subject: [PATCH 3/5] Fix typo --- src/compressed_tensors/transform/factory/matrix_multiply.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/compressed_tensors/transform/factory/matrix_multiply.py b/src/compressed_tensors/transform/factory/matrix_multiply.py index d8537485..0c573b2e 100644 --- a/src/compressed_tensors/transform/factory/matrix_multiply.py +++ b/src/compressed_tensors/transform/factory/matrix_multiply.py @@ -54,13 +54,12 @@ def create_transform(self, module: Module, args: TransformArgs): size = get_transform_size(module, args.location, self.scheme.head_dim) dtype = self.scheme.precision device = get_offloaded_device(module) - precision = self.scheme.precision weight = self.weights[size, dtype, device] if args.inverse: weight = self.inverses[weight] - return RandomMatrixTransform(weight, args, precision, type(module)) + return RandomMatrixTransform(weight, self.scheme, args, type(module)) def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter: # TODO: verify that weight is invertible (has non-zero determinant) From f7011ac5649869eaa65e2f40c62fc69dc566d287 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 6 Aug 2025 01:23:41 -0400 Subject: [PATCH 4/5] Fix exception --- src/compressed_tensors/utils/type.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compressed_tensors/utils/type.py b/src/compressed_tensors/utils/type.py index 90453721..636bffd7 100644 --- a/src/compressed_tensors/utils/type.py +++ b/src/compressed_tensors/utils/type.py @@ -36,7 +36,7 @@ def validate_from_str(name: str) -> torch.dtype: try: value = getattr(torch, name) assert isinstance(value, torch.dtype) - except AttributeError: + except Exception: raise ValueError(f"No such torch dtype `torch.{name}`") return value From 064eacf25237b2b64741014f46f07cb8f9383072 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 11 Aug 2025 10:41:01 -0400 Subject: [PATCH 5/5] do offline in float64 Signed-off-by: Kyle Sayers --- src/compressed_tensors/transform/factory/hadamard.py | 5 +++-- .../transform/factory/matrix_multiply.py | 9 +++++---- src/compressed_tensors/transform/transform_args.py | 6 ++++++ src/compressed_tensors/transform/transform_scheme.py | 4 ++-- 4 files changed, 16 insertions(+), 8 deletions(-) diff --git a/src/compressed_tensors/transform/factory/hadamard.py b/src/compressed_tensors/transform/factory/hadamard.py index c8784126..42611967 100644 --- a/src/compressed_tensors/transform/factory/hadamard.py +++ b/src/compressed_tensors/transform/factory/hadamard.py @@ -95,6 +95,7 @@ def __init__( self.args = args self.module_type = module_type self._scale = torch.tensor(weight.size(0), dtype=self.scheme.precision).sqrt() + self._precision = scheme.precision if args.is_online() else torch.float64 def forward(self, value: Tensor) -> Tensor: weight = self.weight @@ -107,8 +108,8 @@ def forward(self, value: Tensor) -> Tensor: return ( apply_transform_weight( - weight.to(self.scheme.precision), - value.to(self.scheme.precision), + weight.to(self._precision), + value.to(self._precision), self.args.location, self.module_type, ) diff --git a/src/compressed_tensors/transform/factory/matrix_multiply.py b/src/compressed_tensors/transform/factory/matrix_multiply.py index 0c573b2e..7e6103f0 100644 --- a/src/compressed_tensors/transform/factory/matrix_multiply.py +++ b/src/compressed_tensors/transform/factory/matrix_multiply.py @@ -86,11 +86,12 @@ def __init__( self.scheme = scheme self.args = args self.module_type = module_type + self._precision = scheme.precision if args.is_online() else torch.float64 def forward(self, value: Tensor) -> Parameter: return apply_transform_weight( - self.weight.to(self.scheme.precision), - value.to(self.scheme.precision), + self.weight.to(self._precision), + value.to(self._precision), self.args.location, self.module_type, ).to(value.dtype) @@ -98,8 +99,8 @@ def forward(self, value: Tensor) -> Parameter: def right_inverse(self, value: Tensor) -> Tensor: inverse = high_precision_invert(self.weight) return apply_transform_weight( - inverse.to(self.scheme.precision), - value.to(self.scheme.precision), + inverse.to(self._precision), + value.to(self._precision), self.args.location, self.module_type, ).to(value.dtype) diff --git a/src/compressed_tensors/transform/transform_args.py b/src/compressed_tensors/transform/transform_args.py index e94d4d2d..8d2f8973 100644 --- a/src/compressed_tensors/transform/transform_args.py +++ b/src/compressed_tensors/transform/transform_args.py @@ -68,3 +68,9 @@ def wrap_singleton(cls, value): if isinstance(value, str): return [value] return value + + def is_online(self) -> bool: + return self.location not in ( + TransformLocation.WEIGHT_INPUT, + TransformLocation.WEIGHT_OUTPUT, + ) diff --git a/src/compressed_tensors/transform/transform_scheme.py b/src/compressed_tensors/transform/transform_scheme.py index f8d26e76..7236b9dd 100644 --- a/src/compressed_tensors/transform/transform_scheme.py +++ b/src/compressed_tensors/transform/transform_scheme.py @@ -36,8 +36,8 @@ class TransformScheme(BaseModel): :param randomize: True if uniquely randomized transform weights should be used, otherwise use identical transform weights where applicable :param requires_grad: True if weights include gradients for training - :param precision: Precision at which this transform should be applied. This applies - to both weight fusing and online rotations + :param precision: Precision at which this transform should be applied during online + rotations. Fused (offline) rotations are always performed in float64 """ type: str