[Transform] [Utils] Support precision, add torch dtype validation #414

Merged · 5 commits · Aug 11, 2025

9 changes: 9 additions & 0 deletions pyproject.toml
@@ -5,3 +5,12 @@ build-backend = "setuptools.build_meta"
[tool.black]
line-length = 88
target-version = ['py36']

[tool.pytest.ini_options]
markers = [
"unit: tests to ensure code correctness and regression test functionality",
"smoke: quick tests to check basic functionality",
"sanity: tests to ensure that new changes do not break existing functionality",
"regression: detailed tests to ensure major functions work correctly",
"integration: tests which integrate with a third party service such as HF",
]
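
Note: with these markers registered, subsets of the suite can be selected at run time, e.g. pytest -m unit or pytest -m "not integration".
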
23 changes: 15 additions & 8 deletions src/compressed_tensors/transform/factory/hadamard.py
@@ -12,8 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-import math
-from typing import Optional, Union
+from typing import Optional

import torch
from compressed_tensors.transform import TransformArgs, TransformScheme
@@ -26,7 +25,7 @@
from compressed_tensors.utils import get_execution_device, get_offloaded_device
from compressed_tensors.utils.helpers import ParameterizedDefaultDict
from torch import Tensor, device, dtype
-from torch.nn import Linear, Module, Parameter
+from torch.nn import Module, Parameter


@TransformFactory.register("hadamard")
@@ -54,14 +53,14 @@ def create_transform(self, module: Module, args: TransformArgs):
"""
assert hasattr(module, "weight")
size = get_transform_size(module, args.location, self.scheme.head_dim)
-dtype = module.weight.dtype
+dtype = self.scheme.precision
device = get_offloaded_device(module)
exec_device = get_execution_device(module)

factory_kwargs = {"construct_device": exec_device}
weight = self.weights.get(size, dtype, device, factory_kwargs=factory_kwargs)
perm = self.perms[weight] if self.scheme.randomize else None
-return HadamardTransform(weight, perm, args, type(module))
+return HadamardTransform(weight, perm, self.scheme, args, type(module))

def _create_weight(
self,
@@ -85,15 +84,18 @@ def __init__(
self,
weight: Parameter,
perm: Optional[Parameter],
+scheme: TransformScheme,
args: TransformArgs,
module_type: type[torch.nn.Module],
):
super().__init__()
self.weight = weight
self.perm = perm
+self.scheme = scheme
self.args = args
self.module_type = module_type
-self._scale = math.sqrt(weight.size(0))
+self._scale = torch.tensor(weight.size(0), dtype=self.scheme.precision).sqrt()
+self._precision = scheme.precision if args.is_online() else torch.float64

def forward(self, value: Tensor) -> Tensor:
weight = self.weight
@@ -105,6 +107,11 @@ def forward(self, value: Tensor) -> Tensor:
weight = weight.T

return (
-apply_transform_weight(weight, value, self.args.location, self.module_type)
+apply_transform_weight(
+weight.to(self._precision),
+value.to(self._precision),
+self.args.location,
+self.module_type,
+)
/ self._scale
-)
+).to(value.dtype)
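
For intuition, here is a minimal, self-contained sketch (plain torch only, not the library's apply_transform_weight helper) of the casting pattern the new forward uses: both operands are cast to the configured precision, the rotation and 1/sqrt(n) rescaling happen in that precision, and the result is cast back to the caller's dtype.

import torch

def rotate(value: torch.Tensor, weight: torch.Tensor,
           precision: torch.dtype = torch.float32) -> torch.Tensor:
    # cast both operands to the compute precision before multiplying
    out = value.to(precision) @ weight.to(precision)
    # divide by sqrt(n) so the +/-1 Hadamard weight stays orthonormal
    out = out / torch.tensor(weight.size(0), dtype=precision).sqrt()
    # hand the result back in the caller's original dtype
    return out.to(value.dtype)

# 8x8 Hadamard matrix via the Sylvester construction
H2 = torch.tensor([[1.0, 1.0], [1.0, -1.0]])
H = torch.kron(torch.kron(H2, H2), H2)

x = torch.randn(4, 8, dtype=torch.bfloat16)
assert rotate(x, H).dtype == torch.bfloat16
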
25 changes: 17 additions & 8 deletions src/compressed_tensors/transform/factory/matrix_multiply.py
@@ -24,7 +24,7 @@
from compressed_tensors.utils import get_offloaded_device
from compressed_tensors.utils.helpers import ParameterizedDefaultDict
from torch import Tensor, device, dtype
-from torch.nn import Linear, Module, Parameter
+from torch.nn import Module, Parameter


@TransformFactory.register("random-matrix")
@@ -52,14 +52,14 @@ def create_transform(self, module: Module, args: TransformArgs):
"""
assert hasattr(module, "weight")
size = get_transform_size(module, args.location, self.scheme.head_dim)
-dtype = module.weight.dtype
+dtype = self.scheme.precision
device = get_offloaded_device(module)

weight = self.weights[size, dtype, device]
if args.inverse:
weight = self.inverses[weight]

-return RandomMatrixTransform(weight, args, type(module))
+return RandomMatrixTransform(weight, self.scheme, args, type(module))

def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
# TODO: verify that weight is invertible (has non-zero determinant)
@@ -77,25 +77,34 @@ class RandomMatrixTransform(TransformBase):
def __init__(
self,
weight: Tensor,
+scheme: TransformScheme,
args: TransformArgs,
module_type: type[torch.nn.Module],
):
super().__init__()
self.weight = weight # is an inverse if args.inverse
+self.scheme = scheme
self.args = args
self.module_type = module_type
+self._precision = scheme.precision if args.is_online() else torch.float64

def forward(self, value: Tensor) -> Parameter:
return apply_transform_weight(
-self.weight, value, self.args.location, self.module_type
-)
+self.weight.to(self._precision),
+value.to(self._precision),
+self.args.location,
+self.module_type,
+).to(value.dtype)

def right_inverse(self, value: Tensor) -> Tensor:
inverse = high_precision_invert(self.weight)
return apply_transform_weight(
-inverse, value, self.args.location, self.module_type
-)
+inverse.to(self._precision),
+value.to(self._precision),
+self.args.location,
+self.module_type,
+).to(value.dtype)


def high_precision_invert(weight: Tensor) -> Tensor:
-return torch.linalg.inv(weight.to(torch.float32)).to(weight.dtype)
+return torch.linalg.inv(weight.to(torch.float64)).to(weight.dtype)
6 changes: 6 additions & 0 deletions src/compressed_tensors/transform/transform_args.py
@@ -68,3 +68,9 @@ def wrap_singleton(cls, value):
if isinstance(value, str):
return [value]
return value

def is_online(self) -> bool:
return self.location not in (
TransformLocation.WEIGHT_INPUT,
TransformLocation.WEIGHT_OUTPUT,
)
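
A hedged illustration of the new helper's intent, assuming TransformArgs accepts the same string locations used elsewhere in this PR ("input", "weight_input", ...): weight locations are fused offline, everything else is applied online.

from compressed_tensors.transform import TransformArgs

online = TransformArgs(targets=["Linear"], location="input")
fused = TransformArgs(targets=["Linear"], location="weight_input")

assert online.is_online()      # runtime rotation -> runs at scheme.precision
assert not fused.is_online()   # folded into the weight -> kept at float64
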
5 changes: 5 additions & 0 deletions src/compressed_tensors/transform/transform_scheme.py
@@ -14,7 +14,9 @@

from typing import List, Optional

import torch
from compressed_tensors.transform import TransformArgs
from compressed_tensors.utils import TorchDtype
from pydantic import BaseModel, Field


@@ -34,10 +36,13 @@ class TransformScheme(BaseModel):
:param randomize: True if uniquely randomized transform weights should be used,
otherwise use identical transform weights where applicable
:param requires_grad: True if weights include gradients for training
:param precision: Precision at which this transform should be applied during online
rotations. Fused (offline) rotations are always performed in float64
"""

type: str
apply: List[TransformArgs] = Field(default_factory=list)
randomize: bool = Field(default=False)
requires_grad: bool = Field(default=False)
head_dim: Optional[int] = Field(default=None)
precision: TorchDtype = Field(default=torch.float32)
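
A hedged usage sketch of the new field (names follow the diff above; "hadamard" is the type registered by the factory earlier in this PR, and string dtypes are parsed by the TorchDtype annotation added below):

import torch
from compressed_tensors.transform import TransformArgs, TransformScheme

scheme = TransformScheme(
    type="hadamard",
    apply=[TransformArgs(targets=["Linear"], location="input")],
    precision=torch.bfloat16,  # online rotations run in bfloat16; fused ones stay float64
)

# TorchDtype serializes to "torch.bfloat16" and parses it back on reload
reloaded = TransformScheme.model_validate(scheme.model_dump())
assert reloaded.precision == torch.bfloat16
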
1 change: 1 addition & 0 deletions src/compressed_tensors/utils/__init__.py
@@ -21,3 +21,4 @@
from .permute import *
from .safetensors_load import *
from .semi_structured_conversions import *
from .type import *
74 changes: 74 additions & 0 deletions src/compressed_tensors/utils/type.py
@@ -0,0 +1,74 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Annotated, Any

import torch
from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler
from pydantic.json_schema import JsonSchemaValue
from pydantic_core import core_schema


__all__ = ["TorchDtype"]


class _TorchDtypeAnnotation:
@classmethod
def __get_pydantic_core_schema__(
cls,
_source_type: Any,
_handler: GetCoreSchemaHandler,
) -> core_schema.CoreSchema:
# support strings of the form `torch.xxx` or `xxx`
def validate_from_str(name: str) -> torch.dtype:
name = name.removeprefix("torch.")
try:
value = getattr(torch, name)
assert isinstance(value, torch.dtype)
except Exception:
raise ValueError(f"No such torch dtype `torch.{name}`")

return value

# package validation into a schema (which also validates str type)
from_str_schema = core_schema.chain_schema(
[
core_schema.str_schema(),
core_schema.no_info_plain_validator_function(validate_from_str),
]
)

return core_schema.json_or_python_schema(
json_schema=from_str_schema,
python_schema=core_schema.union_schema(
[
# support both torch.dtype or strings
core_schema.is_instance_schema(torch.dtype),
from_str_schema,
]
),
# serialize as `torch.xxx`
serialization=core_schema.plain_serializer_function_ser_schema(
lambda instance: str(instance)
),
)

@classmethod
def __get_pydantic_json_schema__(
cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
) -> JsonSchemaValue:
return handler(core_schema.str_schema())


TorchDtype = Annotated[torch.dtype, _TorchDtypeAnnotation]
25 changes: 14 additions & 11 deletions tests/test_transform/test_transform_config.py
@@ -17,8 +17,8 @@
from compressed_tensors.transform import TransformArgs, TransformConfig, TransformScheme


-@pytest.fixture
-def basic_transform_scheme():
+@pytest.fixture(scope="module")
+def scheme():
targets = ["Embedding"]
location = "input"
basic_args = TransformArgs(targets=targets, location=location)
@@ -29,21 +29,20 @@ )
)


-def test_basic(basic_transform_scheme):
-config = TransformConfig(
+@pytest.fixture(scope="module")
+def config(scheme):
+return TransformConfig(
config_groups={
-"transform_0": basic_transform_scheme,
+"transform_0": scheme,
}
)


+def test_basic(config):
assert isinstance(config.config_groups.get("transform_0"), TransformScheme)


-def test_to_dict(basic_transform_scheme):
-config = TransformConfig(
-config_groups={
-"transform_0": basic_transform_scheme,
-}
-)
+def test_to_dict(config):
config_dict = config.model_dump()
assert "config_groups" in config_dict.keys()

@@ -69,3 +68,7 @@ def test_multiple_groups():
config = TransformConfig(
config_groups={"transform_0": scheme_1, "transform_1": scheme_2}
)


def test_reload(config):
assert config == TransformConfig.model_validate(config.model_dump())
79 changes: 79 additions & 0 deletions tests/test_utils/test_type.py
@@ -0,0 +1,79 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch
from compressed_tensors.utils.type import TorchDtype
from pydantic import BaseModel, Field
from pydantic_core._pydantic_core import ValidationError


class DummyModel(BaseModel):
dtype: TorchDtype = Field(default=torch.float32)


@pytest.mark.unit
def test_default_value():
model = DummyModel()
assert model.dtype == torch.float32


@pytest.mark.unit
def test_value_override():
model = DummyModel()
model.dtype = torch.float16
assert model.dtype == torch.float16


@pytest.mark.unit
def test_validation():
DummyModel(dtype=torch.float16)
DummyModel(dtype="torch.float16")
DummyModel(dtype="float16")

with pytest.raises(ValidationError):
model = DummyModel(dtype="notatype")


@pytest.mark.unit
def test_serialization():
model = DummyModel()
assert model.model_dump()["dtype"] == "torch.float32"
assert DummyModel.model_validate(model.model_dump()) == model

model = DummyModel(dtype=torch.float16)
assert model.model_dump()["dtype"] == "torch.float16"
assert DummyModel.model_validate(model.model_dump()) == model

model = DummyModel()
model.dtype = torch.float16
assert model.model_dump()["dtype"] == "torch.float16"
assert DummyModel.model_validate(model.model_dump()) == model


@pytest.mark.unit
def test_deserialization():
dummy_dict = {"dtype": "torch.float16"}
assert DummyModel.model_validate(dummy_dict).dtype == torch.float16

dummy_dict = {"dtype": "float16"}
assert DummyModel.model_validate(dummy_dict).dtype == torch.float16

with pytest.raises(ValueError):
dummy_dict = {"dtype": "notatype"}
DummyModel.model_validate(dummy_dict)

with pytest.raises(ValueError):
dummy_dict = {"dtype": "torch.notatype"}
DummyModel.model_validate(dummy_dict)