diff --git a/pyproject.toml b/pyproject.toml index b0f019be..1238a376 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,3 +5,12 @@ build-backend = "setuptools.build_meta" [tool.black] line-length = 88 target-version = ['py36'] + +[tool.pytest.ini_options] +markers = [ + "unit: tests to ensure code correctness and regression test functionality", + "smoke: quick tests to check basic functionality", + "sanity: tests to ensure that new changes do not break existing functionality", + "regression: detailed tests to ensure major functions work correctly", + "integration: tests which integrate with a third party service such as HF", +] \ No newline at end of file diff --git a/src/compressed_tensors/transform/factory/hadamard.py b/src/compressed_tensors/transform/factory/hadamard.py index 02ebd89b..42611967 100644 --- a/src/compressed_tensors/transform/factory/hadamard.py +++ b/src/compressed_tensors/transform/factory/hadamard.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import math -from typing import Optional, Union +from typing import Optional import torch from compressed_tensors.transform import TransformArgs, TransformScheme @@ -26,7 +25,7 @@ from compressed_tensors.utils import get_execution_device, get_offloaded_device from compressed_tensors.utils.helpers import ParameterizedDefaultDict from torch import Tensor, device, dtype -from torch.nn import Linear, Module, Parameter +from torch.nn import Module, Parameter @TransformFactory.register("hadamard") @@ -54,14 +53,14 @@ def create_transform(self, module: Module, args: TransformArgs): """ assert hasattr(module, "weight") size = get_transform_size(module, args.location, self.scheme.head_dim) - dtype = module.weight.dtype + dtype = self.scheme.precision device = get_offloaded_device(module) exec_device = get_execution_device(module) factory_kwargs = {"construct_device": exec_device} weight = self.weights.get(size, dtype, device, factory_kwargs=factory_kwargs) perm = self.perms[weight] if self.scheme.randomize else None - return HadamardTransform(weight, perm, args, type(module)) + return HadamardTransform(weight, perm, self.scheme, args, type(module)) def _create_weight( self, @@ -85,15 +84,18 @@ def __init__( self, weight: Parameter, perm: Optional[Parameter], + scheme: TransformScheme, args: TransformArgs, module_type: type[torch.nn.Module], ): super().__init__() self.weight = weight self.perm = perm + self.scheme = scheme self.args = args self.module_type = module_type - self._scale = math.sqrt(weight.size(0)) + self._scale = torch.tensor(weight.size(0), dtype=self.scheme.precision).sqrt() + self._precision = scheme.precision if args.is_online() else torch.float64 def forward(self, value: Tensor) -> Tensor: weight = self.weight @@ -105,6 +107,11 @@ def forward(self, value: Tensor) -> Tensor: weight = weight.T return ( - apply_transform_weight(weight, value, self.args.location, self.module_type) + apply_transform_weight( + weight.to(self._precision), + value.to(self._precision), + self.args.location, + self.module_type, + ) / self._scale - ) + ).to(value.dtype) diff --git a/src/compressed_tensors/transform/factory/matrix_multiply.py b/src/compressed_tensors/transform/factory/matrix_multiply.py index 8b829451..7e6103f0 100644 --- a/src/compressed_tensors/transform/factory/matrix_multiply.py +++ b/src/compressed_tensors/transform/factory/matrix_multiply.py @@ -24,7 +24,7 @@ from compressed_tensors.utils import get_offloaded_device from compressed_tensors.utils.helpers import ParameterizedDefaultDict from torch import Tensor, device, dtype -from torch.nn import Linear, Module, Parameter +from torch.nn import Module, Parameter @TransformFactory.register("random-matrix") @@ -52,14 +52,14 @@ def create_transform(self, module: Module, args: TransformArgs): """ assert hasattr(module, "weight") size = get_transform_size(module, args.location, self.scheme.head_dim) - dtype = module.weight.dtype + dtype = self.scheme.precision device = get_offloaded_device(module) weight = self.weights[size, dtype, device] if args.inverse: weight = self.inverses[weight] - return RandomMatrixTransform(weight, args, type(module)) + return RandomMatrixTransform(weight, self.scheme, args, type(module)) def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter: # TODO: verify that weight is invertible (has non-zero determinant) @@ -77,25 +77,34 @@ class RandomMatrixTransform(TransformBase): def __init__( self, weight: Tensor, + scheme: TransformScheme, args: TransformArgs, module_type: type[torch.nn.Module], ): super().__init__() self.weight = weight # is an inverse if args.inverse + self.scheme = scheme self.args = args self.module_type = module_type + self._precision = scheme.precision if args.is_online() else torch.float64 def forward(self, value: Tensor) -> Parameter: return apply_transform_weight( - self.weight, value, self.args.location, self.module_type - ) + self.weight.to(self._precision), + value.to(self._precision), + self.args.location, + self.module_type, + ).to(value.dtype) def right_inverse(self, value: Tensor) -> Tensor: inverse = high_precision_invert(self.weight) return apply_transform_weight( - inverse, value, self.args.location, self.module_type - ) + inverse.to(self._precision), + value.to(self._precision), + self.args.location, + self.module_type, + ).to(value.dtype) def high_precision_invert(weight: Tensor) -> Tensor: - return torch.linalg.inv(weight.to(torch.float32)).to(weight.dtype) + return torch.linalg.inv(weight.to(torch.float64)).to(weight.dtype) diff --git a/src/compressed_tensors/transform/transform_args.py b/src/compressed_tensors/transform/transform_args.py index e94d4d2d..8d2f8973 100644 --- a/src/compressed_tensors/transform/transform_args.py +++ b/src/compressed_tensors/transform/transform_args.py @@ -68,3 +68,9 @@ def wrap_singleton(cls, value): if isinstance(value, str): return [value] return value + + def is_online(self) -> bool: + return self.location not in ( + TransformLocation.WEIGHT_INPUT, + TransformLocation.WEIGHT_OUTPUT, + ) diff --git a/src/compressed_tensors/transform/transform_scheme.py b/src/compressed_tensors/transform/transform_scheme.py index 1620c541..7236b9dd 100644 --- a/src/compressed_tensors/transform/transform_scheme.py +++ b/src/compressed_tensors/transform/transform_scheme.py @@ -14,7 +14,9 @@ from typing import List, Optional +import torch from compressed_tensors.transform import TransformArgs +from compressed_tensors.utils import TorchDtype from pydantic import BaseModel, Field @@ -34,6 +36,8 @@ class TransformScheme(BaseModel): :param randomize: True if uniquely randomized transform weights should be used, otherwise use identical transform weights where applicable :param requires_grad: True if weights include gradients for training + :param precision: Precision at which this transform should be applied during online + rotations. Fused (offline) rotations are always performed in float64 """ type: str @@ -41,3 +45,4 @@ class TransformScheme(BaseModel): randomize: bool = Field(default=False) requires_grad: bool = Field(default=False) head_dim: Optional[int] = Field(default=None) + precision: TorchDtype = Field(default=torch.float32) diff --git a/src/compressed_tensors/utils/__init__.py b/src/compressed_tensors/utils/__init__.py index c5f401d1..8763e6ee 100644 --- a/src/compressed_tensors/utils/__init__.py +++ b/src/compressed_tensors/utils/__init__.py @@ -21,3 +21,4 @@ from .permute import * from .safetensors_load import * from .semi_structured_conversions import * +from .type import * diff --git a/src/compressed_tensors/utils/type.py b/src/compressed_tensors/utils/type.py new file mode 100644 index 00000000..636bffd7 --- /dev/null +++ b/src/compressed_tensors/utils/type.py @@ -0,0 +1,74 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Annotated, Any + +import torch +from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler +from pydantic.json_schema import JsonSchemaValue +from pydantic_core import core_schema + + +__all__ = ["TorchDtype"] + + +class _TorchDtypeAnnotation: + @classmethod + def __get_pydantic_core_schema__( + cls, + _source_type: Any, + _handler: GetCoreSchemaHandler, + ) -> core_schema.CoreSchema: + # support strings of the form `torch.xxx` or `xxx` + def validate_from_str(name: str) -> torch.dtype: + name = name.removeprefix("torch.") + try: + value = getattr(torch, name) + assert isinstance(value, torch.dtype) + except Exception: + raise ValueError(f"No such torch dtype `torch.{name}`") + + return value + + # package validation into a schema (which also validates str type) + from_str_schema = core_schema.chain_schema( + [ + core_schema.str_schema(), + core_schema.no_info_plain_validator_function(validate_from_str), + ] + ) + + return core_schema.json_or_python_schema( + json_schema=from_str_schema, + python_schema=core_schema.union_schema( + [ + # support both torch.dtype or strings + core_schema.is_instance_schema(torch.dtype), + from_str_schema, + ] + ), + # serialize as `torch.xxx` + serialization=core_schema.plain_serializer_function_ser_schema( + lambda instance: str(instance) + ), + ) + + @classmethod + def __get_pydantic_json_schema__( + cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler + ) -> JsonSchemaValue: + return handler(core_schema.str_schema()) + + +TorchDtype = Annotated[torch.dtype, _TorchDtypeAnnotation] diff --git a/tests/test_transform/test_transform_config.py b/tests/test_transform/test_transform_config.py index 52167cfd..ad8e6645 100644 --- a/tests/test_transform/test_transform_config.py +++ b/tests/test_transform/test_transform_config.py @@ -17,8 +17,8 @@ from compressed_tensors.transform import TransformArgs, TransformConfig, TransformScheme -@pytest.fixture -def basic_transform_scheme(): +@pytest.fixture(scope="module") +def scheme(): targets = ["Embedding"] location = "input" basic_args = TransformArgs(targets=targets, location=location) @@ -29,21 +29,20 @@ def basic_transform_scheme(): ) -def test_basic(basic_transform_scheme): - config = TransformConfig( +@pytest.fixture(scope="module") +def config(scheme): + return TransformConfig( config_groups={ - "transform_0": basic_transform_scheme, + "transform_0": scheme, } ) + + +def test_basic(config): assert isinstance(config.config_groups.get("transform_0"), TransformScheme) -def test_to_dict(basic_transform_scheme): - config = TransformConfig( - config_groups={ - "transform_0": basic_transform_scheme, - } - ) +def test_to_dict(config): config_dict = config.model_dump() assert "config_groups" in config_dict.keys() @@ -69,3 +68,7 @@ def test_multiple_groups(): config = TransformConfig( config_groups={"transform_0": scheme_1, "transform_1": scheme_2} ) + + +def test_reload(config): + assert config == TransformConfig.model_validate(config.model_dump()) diff --git a/tests/test_utils/test_type.py b/tests/test_utils/test_type.py new file mode 100644 index 00000000..4eb38091 --- /dev/null +++ b/tests/test_utils/test_type.py @@ -0,0 +1,79 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch +from compressed_tensors.utils.type import TorchDtype +from pydantic import BaseModel, Field +from pydantic_core._pydantic_core import ValidationError + + +class DummyModel(BaseModel): + dtype: TorchDtype = Field(default=torch.float32) + + +@pytest.mark.unit +def test_default_value(): + model = DummyModel() + assert model.dtype == torch.float32 + + +@pytest.mark.unit +def test_value_override(): + model = DummyModel() + model.dtype = torch.float16 + assert model.dtype == torch.float16 + + +@pytest.mark.unit +def test_validation(): + DummyModel(dtype=torch.float16) + DummyModel(dtype="torch.float16") + DummyModel(dtype="float16") + + with pytest.raises(ValidationError): + model = DummyModel(dtype="notatype") + + +@pytest.mark.unit +def test_serialization(): + model = DummyModel() + assert model.model_dump()["dtype"] == "torch.float32" + assert DummyModel.model_validate(model.model_dump()) == model + + model = DummyModel(dtype=torch.float16) + assert model.model_dump()["dtype"] == "torch.float16" + assert DummyModel.model_validate(model.model_dump()) == model + + model = DummyModel() + model.dtype = torch.float16 + assert model.model_dump()["dtype"] == "torch.float16" + assert DummyModel.model_validate(model.model_dump()) == model + + +@pytest.mark.unit +def test_deserialization(): + dummy_dict = {"dtype": "torch.float16"} + assert DummyModel.model_validate(dummy_dict).dtype == torch.float16 + + dummy_dict = {"dtype": "float16"} + assert DummyModel.model_validate(dummy_dict).dtype == torch.float16 + + with pytest.raises(ValueError): + dummy_dict = {"dtype": "notatype"} + DummyModel.model_validate(dummy_dict) + + with pytest.raises(ValueError): + dummy_dict = {"dtype": "torch.notatype"} + DummyModel.model_validate(dummy_dict)