9 changes: 9 additions & 0 deletions pyproject.toml
@@ -5,3 +5,12 @@ build-backend = "setuptools.build_meta"
 [tool.black]
 line-length = 88
 target-version = ['py36']
+
+[tool.pytest.ini_options]
+markers = [
+    "unit: tests to ensure code correctness and regression test functionality",
+    "smoke: quick tests to check basic functionality",
+    "sanity: tests to ensure that new changes do not break existing functionality",
+    "regression: detailed tests to ensure major functions work correctly",
+    "integration: tests which integrate with a third party service such as HF",
+]
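
Registering these markers under [tool.pytest.ini_options] lets suites be selected by marker (e.g. `pytest -m unit`) and avoids pytest's unknown-mark warning. A minimal sketch of a marked test (the test itself is hypothetical, not part of this PR):

import pytest


@pytest.mark.unit
def test_example():
    # illustrative only: any test tagged with a registered marker
    assert 1 + 1 == 2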
22 changes: 14 additions & 8 deletions src/compressed_tensors/transform/factory/hadamard.py
@@ -12,8 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import math
-from typing import Optional, Union
+from typing import Optional
 
 import torch
 from compressed_tensors.transform import TransformArgs, TransformScheme
@@ -26,7 +25,7 @@
 from compressed_tensors.utils import get_execution_device, get_offloaded_device
 from compressed_tensors.utils.helpers import ParameterizedDefaultDict
 from torch import Tensor, device, dtype
-from torch.nn import Linear, Module, Parameter
+from torch.nn import Module, Parameter
 
 
 @TransformFactory.register("hadamard")
@@ -54,14 +53,14 @@ def create_transform(self, module: Module, args: TransformArgs):
         """
         assert hasattr(module, "weight")
         size = get_transform_size(module, args.location, self.scheme.head_dim)
-        dtype = module.weight.dtype
+        dtype = self.scheme.precision
         device = get_offloaded_device(module)
         exec_device = get_execution_device(module)
 
         factory_kwargs = {"construct_device": exec_device}
         weight = self.weights.get(size, dtype, device, factory_kwargs=factory_kwargs)
         perm = self.perms[weight] if self.scheme.randomize else None
-        return HadamardTransform(weight, perm, args, type(module))
+        return HadamardTransform(weight, perm, self.scheme, args, type(module))
 
     def _create_weight(
         self,
@@ -85,15 +84,17 @@ def __init__(
         self,
         weight: Parameter,
         perm: Optional[Parameter],
+        scheme: TransformScheme,
         args: TransformArgs,
         module_type: type[torch.nn.Module],
     ):
         super().__init__()
         self.weight = weight
         self.perm = perm
+        self.scheme = scheme
         self.args = args
         self.module_type = module_type
-        self._scale = math.sqrt(weight.size(0))
+        self._scale = torch.tensor(weight.size(0), dtype=self.scheme.precision).sqrt()
 
     def forward(self, value: Tensor) -> Tensor:
         weight = self.weight
@@ -105,6 +106,11 @@ def forward(self, value: Tensor) -> Tensor:
             weight = weight.T
 
         return (
-            apply_transform_weight(weight, value, self.args.location, self.module_type)
+            apply_transform_weight(
+                weight.to(self.scheme.precision),
+                value.to(self.scheme.precision),
+                self.args.location,
+                self.module_type,
+            )
             / self._scale
-        )
+        ).to(value.dtype)
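
The pattern here casts both operands to scheme.precision before the transform, divides by sqrt(size), and casts the result back to the caller's dtype. A standalone sketch of that numeric contract (the helper is hypothetical; the plain matmul stands in for apply_transform_weight, whose real dispatch on location and module type is elided):

import torch


def apply_hadamard(weight: torch.Tensor, value: torch.Tensor,
                   precision: torch.dtype = torch.float32) -> torch.Tensor:
    # compute in `precision`, normalize by sqrt(size), return in input dtype
    scale = torch.tensor(weight.size(0), dtype=precision).sqrt()
    out = value.to(precision) @ weight.to(precision) / scale
    return out.to(value.dtype)


x = torch.randn(4, 8, dtype=torch.bfloat16)
w = torch.eye(8)  # identity stands in for a real Hadamard matrix
assert apply_hadamard(w, x).dtype == torch.bfloat16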
24 changes: 16 additions & 8 deletions src/compressed_tensors/transform/factory/matrix_multiply.py
@@ -24,7 +24,7 @@
 from compressed_tensors.utils import get_offloaded_device
 from compressed_tensors.utils.helpers import ParameterizedDefaultDict
 from torch import Tensor, device, dtype
-from torch.nn import Linear, Module, Parameter
+from torch.nn import Module, Parameter
 
 
 @TransformFactory.register("random-matrix")
@@ -52,14 +52,14 @@ def create_transform(self, module: Module, args: TransformArgs):
         """
         assert hasattr(module, "weight")
         size = get_transform_size(module, args.location, self.scheme.head_dim)
-        dtype = module.weight.dtype
+        dtype = self.scheme.precision
         device = get_offloaded_device(module)
 
         weight = self.weights[size, dtype, device]
         if args.inverse:
             weight = self.inverses[weight]
 
-        return RandomMatrixTransform(weight, args, type(module))
+        return RandomMatrixTransform(weight, self.scheme, args, type(module))
 
     def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
         # TODO: verify that weight is invertible (has non-zero determinant)
@@ -77,25 +77,33 @@ class RandomMatrixTransform(TransformBase):
     def __init__(
         self,
         weight: Tensor,
+        scheme: TransformScheme,
         args: TransformArgs,
         module_type: type[torch.nn.Module],
     ):
         super().__init__()
         self.weight = weight  # is an inverse if args.inverse
+        self.scheme = scheme
         self.args = args
         self.module_type = module_type
 
     def forward(self, value: Tensor) -> Parameter:
         return apply_transform_weight(
-            self.weight, value, self.args.location, self.module_type
-        )
+            self.weight.to(self.scheme.precision),
+            value.to(self.scheme.precision),
+            self.args.location,
+            self.module_type,
+        ).to(value.dtype)
 
     def right_inverse(self, value: Tensor) -> Tensor:
         inverse = high_precision_invert(self.weight)
         return apply_transform_weight(
-            inverse, value, self.args.location, self.module_type
-        )
+            inverse.to(self.scheme.precision),
+            value.to(self.scheme.precision),
+            self.args.location,
+            self.module_type,
+        ).to(value.dtype)
 
 
 def high_precision_invert(weight: Tensor) -> Tensor:
-    return torch.linalg.inv(weight.to(torch.float32)).to(weight.dtype)
+    return torch.linalg.inv(weight.to(torch.float64)).to(weight.dtype)
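
Moving high_precision_invert from float32 to float64 matters most for ill-conditioned matrices, where a float32 inverse can drift measurably. A quick illustrative check of the motivation (not part of the PR):

import torch

a = torch.randn(64, 64, dtype=torch.float64)
a = a @ a.T + 1e-6 * torch.eye(64, dtype=torch.float64)  # nearly singular SPD

inv32 = torch.linalg.inv(a.to(torch.float32)).to(torch.float64)
inv64 = torch.linalg.inv(a)

eye = torch.eye(64, dtype=torch.float64)
print((a @ inv32 - eye).abs().max())  # residual of the float32 inverse
print((a @ inv64 - eye).abs().max())  # typically orders of magnitude smaller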
5 changes: 5 additions & 0 deletions src/compressed_tensors/transform/transform_scheme.py
@@ -14,7 +14,9 @@
 
 from typing import List, Optional
 
+import torch
 from compressed_tensors.transform import TransformArgs
+from compressed_tensors.utils import TorchDtype
 from pydantic import BaseModel, Field
 
 
@@ -34,10 +36,13 @@ class TransformScheme(BaseModel):
     :param randomize: True if uniquely randomized transform weights should be used,
         otherwise use identical transform weights where applicable
     :param requires_grad: True if weights include gradients for training
+    :param precision: Precision at which this transform should be applied. This applies
+        to both weight fusing and online rotations
     """
 
     type: str
     apply: List[TransformArgs] = Field(default_factory=list)
     randomize: bool = Field(default=False)
     requires_grad: bool = Field(default=False)
     head_dim: Optional[int] = Field(default=None)
+    precision: TorchDtype = Field(default=torch.float32)
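
With the new field, a scheme pins its compute precision at construction time; string values are coerced by TorchDtype (added below). A quick sketch based on the fields shown here:

import torch
from compressed_tensors.transform import TransformArgs, TransformScheme

scheme = TransformScheme(
    type="hadamard",
    apply=[TransformArgs(targets=["Linear"], location="input")],
    precision="float16",  # string form; parsed to torch.float16
)
assert scheme.precision is torch.float16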
1 change: 1 addition & 0 deletions src/compressed_tensors/utils/__init__.py
@@ -21,3 +21,4 @@
 from .permute import *
 from .safetensors_load import *
 from .semi_structured_conversions import *
+from .type import *
74 changes: 74 additions & 0 deletions src/compressed_tensors/utils/type.py
@@ -0,0 +1,74 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Annotated, Any
+
+import torch
+from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler
+from pydantic.json_schema import JsonSchemaValue
+from pydantic_core import core_schema
+
+
+__all__ = ["TorchDtype"]
+
+
+class _TorchDtypeAnnotation:
+    @classmethod
+    def __get_pydantic_core_schema__(
+        cls,
+        _source_type: Any,
+        _handler: GetCoreSchemaHandler,
+    ) -> core_schema.CoreSchema:
+        # support strings of the form `torch.xxx` or `xxx`
+        def validate_from_str(name: str) -> torch.dtype:
+            name = name.removeprefix("torch.")
+            try:
+                value = getattr(torch, name)
+                assert isinstance(value, torch.dtype)
+            except Exception:
+                raise ValueError(f"No such torch dtype `torch.{name}`")
+
+            return value
+
+        # package validation into a schema (which also validates str type)
+        from_str_schema = core_schema.chain_schema(
+            [
+                core_schema.str_schema(),
+                core_schema.no_info_plain_validator_function(validate_from_str),
+            ]
+        )
+
+        return core_schema.json_or_python_schema(
+            json_schema=from_str_schema,
+            python_schema=core_schema.union_schema(
+                [
+                    # support both torch.dtype or strings
+                    core_schema.is_instance_schema(torch.dtype),
+                    from_str_schema,
+                ]
+            ),
+            # serialize as `torch.xxx`
+            serialization=core_schema.plain_serializer_function_ser_schema(
+                lambda instance: str(instance)
+            ),
+        )
+
+    @classmethod
+    def __get_pydantic_json_schema__(
+        cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
+    ) -> JsonSchemaValue:
+        return handler(core_schema.str_schema())
+
+
+TorchDtype = Annotated[torch.dtype, _TorchDtypeAnnotation]
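
In short, TorchDtype behaves like torch.dtype at runtime but validates from strings and serializes to the `torch.xxx` form. A minimal round-trip sketch (the model name is hypothetical; the behavior mirrors the tests below):

import torch
from pydantic import BaseModel
from compressed_tensors.utils import TorchDtype


class Config(BaseModel):
    dtype: TorchDtype = torch.float32


cfg = Config(dtype="bfloat16")  # "torch.bfloat16" is accepted too
assert cfg.dtype is torch.bfloat16
assert cfg.model_dump()["dtype"] == "torch.bfloat16"
assert Config.model_validate(cfg.model_dump()) == cfg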
25 changes: 14 additions & 11 deletions tests/test_transform/test_transform_config.py
@@ -17,8 +17,8 @@
 from compressed_tensors.transform import TransformArgs, TransformConfig, TransformScheme
 
 
-@pytest.fixture
-def basic_transform_scheme():
+@pytest.fixture(scope="module")
+def scheme():
     targets = ["Embedding"]
     location = "input"
     basic_args = TransformArgs(targets=targets, location=location)
@@ -29,21 +29,20 @@
     )
 
 
-def test_basic(basic_transform_scheme):
-    config = TransformConfig(
+@pytest.fixture(scope="module")
+def config(scheme):
+    return TransformConfig(
         config_groups={
-            "transform_0": basic_transform_scheme,
+            "transform_0": scheme,
         }
     )
+
+
+def test_basic(config):
     assert isinstance(config.config_groups.get("transform_0"), TransformScheme)
 
 
-def test_to_dict(basic_transform_scheme):
-    config = TransformConfig(
-        config_groups={
-            "transform_0": basic_transform_scheme,
-        }
-    )
+def test_to_dict(config):
     config_dict = config.model_dump()
     assert "config_groups" in config_dict.keys()
 
@@ -69,3 +68,7 @@ def test_multiple_groups():
     config = TransformConfig(
         config_groups={"transform_0": scheme_1, "transform_1": scheme_2}
     )
+
+
+def test_reload(config):
+    assert config == TransformConfig.model_validate(config.model_dump())
79 changes: 79 additions & 0 deletions tests/test_utils/test_type.py
@@ -0,0 +1,79 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+import torch
+from compressed_tensors.utils.type import TorchDtype
+from pydantic import BaseModel, Field
+from pydantic import ValidationError
+
+
+class DummyModel(BaseModel):
+    dtype: TorchDtype = Field(default=torch.float32)
+
+
+@pytest.mark.unit
+def test_default_value():
+    model = DummyModel()
+    assert model.dtype == torch.float32
+
+
+@pytest.mark.unit
+def test_value_override():
+    model = DummyModel()
+    model.dtype = torch.float16
+    assert model.dtype == torch.float16
+
+
+@pytest.mark.unit
+def test_validation():
+    DummyModel(dtype=torch.float16)
+    DummyModel(dtype="torch.float16")
+    DummyModel(dtype="float16")
+
+    with pytest.raises(ValidationError):
+        model = DummyModel(dtype="notatype")
+
+
+@pytest.mark.unit
+def test_serialization():
+    model = DummyModel()
+    assert model.model_dump()["dtype"] == "torch.float32"
+    assert DummyModel.model_validate(model.model_dump()) == model
+
+    model = DummyModel(dtype=torch.float16)
+    assert model.model_dump()["dtype"] == "torch.float16"
+    assert DummyModel.model_validate(model.model_dump()) == model
+
+    model = DummyModel()
+    model.dtype = torch.float16
+    assert model.model_dump()["dtype"] == "torch.float16"
+    assert DummyModel.model_validate(model.model_dump()) == model
+
+
+@pytest.mark.unit
+def test_deserialization():
+    dummy_dict = {"dtype": "torch.float16"}
+    assert DummyModel.model_validate(dummy_dict).dtype == torch.float16
+
+    dummy_dict = {"dtype": "float16"}
+    assert DummyModel.model_validate(dummy_dict).dtype == torch.float16
+
+    with pytest.raises(ValueError):
+        dummy_dict = {"dtype": "notatype"}
+        DummyModel.model_validate(dummy_dict)
+
+    with pytest.raises(ValueError):
+        dummy_dict = {"dtype": "torch.notatype"}
+        DummyModel.model_validate(dummy_dict)