
Commit ac049db

Merge main into refactor_utils

2 parents: 1520a25 + b78eb8f

File tree

24 files changed: +520, -163 lines


pyproject.toml

Lines changed: 9 additions & 0 deletions

@@ -5,3 +5,12 @@ build-backend = "setuptools.build_meta"
 [tool.black]
 line-length = 88
 target-version = ['py36']
+
+[tool.pytest.ini_options]
+markers = [
+    "unit: tests to ensure code correctness and regression test functionality",
+    "smoke: quick tests to check basic functionality",
+    "sanity: tests to ensure that new changes do not break existing functionality",
+    "regression: detailed tests to ensure major functions work correctly",
+    "integration: tests which integrate with a third party service such as HF",
+]
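Registering the markers here silences PytestUnknownMarkWarning and lets the suite be filtered by category, e.g. pytest -m unit. A minimal sketch of how a test opts into one of the new markers (the file and test names below are hypothetical, not part of this commit):

# test_markers_example.py (hypothetical)
import pytest

@pytest.mark.unit
def test_round_trip():
    assert int("8") == 8

@pytest.mark.smoke
def test_package_imports():
    import compressed_tensors  # noqa: F401  (basic availability check)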

src/compressed_tensors/base.py

Lines changed: 8 additions & 3 deletions

@@ -12,9 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-SPARSITY_CONFIG_NAME = "sparsity_config"
+# configs
 QUANTIZATION_CONFIG_NAME = "quantization_config"
-COMPRESSION_CONFIG_NAME = "compression_config"
-KV_CACHE_SCHEME_NAME = "kv_cache_scheme"
+SPARSITY_CONFIG_NAME = "sparsity_config"
+TRANSFORM_CONFIG_NAME = "transform_config"
+
+# required fields
 COMPRESSION_VERSION_NAME = "version"
 QUANTIZATION_METHOD_NAME = "quant_method"
+
+# auxiliary configs
+KV_CACHE_SCHEME_NAME = "kv_cache_scheme"
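For orientation, these constants are the keys under which each config is serialized; a hedged sketch of how they might be assembled into model metadata (the dict shape here is illustrative only, not taken from this commit):

from compressed_tensors import (
    QUANTIZATION_CONFIG_NAME,
    SPARSITY_CONFIG_NAME,
    TRANSFORM_CONFIG_NAME,
)

# Illustrative metadata block keyed by the constants above
metadata = {
    QUANTIZATION_CONFIG_NAME: {"quant_method": "compressed-tensors"},
    SPARSITY_CONFIG_NAME: {"format": "dense"},
    TRANSFORM_CONFIG_NAME: {"config_groups": {}},
}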

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 149 additions & 60 deletions
(Large diff not rendered by default.)

src/compressed_tensors/config/base.py

Lines changed: 1 addition & 0 deletions

@@ -32,6 +32,7 @@ class CompressionFormat(Enum):
     naive_quantized = "naive-quantized"
     pack_quantized = "pack-quantized"
     marlin_24 = "marlin-24"
+    mixed_precision = "mixed-precision"
     nvfp4_pack_quantized = "nvfp4-pack-quantized"
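Since CompressionFormat is a plain Enum, the new member round-trips by value like the others:

from compressed_tensors.config import CompressionFormat

fmt = CompressionFormat("mixed-precision")  # lookup by serialized value
assert fmt is CompressionFormat.mixed_precision
assert fmt.value == "mixed-precision"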

src/compressed_tensors/quantization/quant_args.py

Lines changed: 3 additions & 1 deletion

@@ -19,7 +19,7 @@
 import torch
 from compressed_tensors.utils import Aliasable
 from compressed_tensors.utils.helpers import deprecated
-from pydantic import BaseModel, Field, field_validator, model_validator
+from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator


 __all__ = [
@@ -358,6 +358,8 @@ def pytorch_dtype(self) -> torch.dtype:
     def get_observer(self) -> str:
        return self.observer

+    model_config = ConfigDict(extra="forbid")
+

 def round_to_quantized_type(
     tensor: torch.Tensor, args: QuantizationArgs
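With extra="forbid", pydantic now rejects unrecognized fields on QuantizationArgs instead of silently ignoring them. A standalone sketch of the pydantic v2 behavior (toy model, not the library's class):

from pydantic import BaseModel, ConfigDict, ValidationError

class StrictArgs(BaseModel):
    model_config = ConfigDict(extra="forbid")
    num_bits: int = 8

StrictArgs(num_bits=4)  # ok
try:
    StrictArgs(num_bits=4, nmu_bits=4)  # typo'd field name
except ValidationError as err:
    print(err)  # "Extra inputs are not permitted"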

src/compressed_tensors/quantization/quant_config.py

Lines changed: 14 additions & 2 deletions

@@ -13,7 +13,7 @@
 # limitations under the License.

 from enum import Enum
-from typing import Dict, List, Optional, Union
+from typing import Annotated, Any, Dict, List, Optional, Union

 from compressed_tensors.config import CompressionFormat
 from compressed_tensors.quantization.quant_args import DynamicType, QuantizationArgs
@@ -26,7 +26,7 @@
     module_type,
     parse_out_kv_cache_args,
 )
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, ConfigDict, Field
 from torch.nn import Module


@@ -142,6 +142,9 @@ class QuantizationConfig(BaseModel):
     quantization_status: QuantizationStatus = QuantizationStatus.INITIALIZED
     global_compression_ratio: Optional[float] = None
     ignore: Optional[List[str]] = Field(default_factory=list)
+    # `run_compressed` is a dummy, unused arg for backwards compatibility
+    # see: https://github.com/huggingface/transformers/pull/39324
+    run_compressed: Annotated[Any, Field(exclude=True)] = None

     def model_post_init(self, __context):
         """
@@ -231,6 +234,12 @@ def from_pretrained(
                 format = CompressionFormat.int_quantized.value
             else:
                 format = CompressionFormat.dense.value
+        elif isinstance(format, list):
+            format = (
+                CompressionFormat.mixed_precision.value
+                if len(format) > 1
+                else format[0]
+            )

         return QuantizationConfig(
             config_groups=config_groups,
@@ -254,3 +263,6 @@ def requires_calibration_data(self):
                 return True

         return False
+
+    # TODO set `extra="forbid"` when upstream transformers is compatible
+    model_config = ConfigDict(extra="ignore")
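The new elif collapses a list of per-scheme formats into one global value: multiple distinct formats serialize as mixed-precision, while a single-element list unwraps to its only entry. The rule in isolation (a hypothetical standalone function mirroring the branch above):

from compressed_tensors.config import CompressionFormat

def collapse_format(format):
    # Mirrors the list-handling branch added to from_pretrained
    if isinstance(format, list):
        return (
            CompressionFormat.mixed_precision.value
            if len(format) > 1
            else format[0]
        )
    return format

assert collapse_format(["pack-quantized", "float-quantized"]) == "mixed-precision"
assert collapse_format(["pack-quantized"]) == "pack-quantized"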

src/compressed_tensors/quantization/quant_scheme.py

Lines changed: 13 additions & 2 deletions

@@ -14,15 +14,16 @@

 import warnings
 from copy import deepcopy
-from typing import Any, Dict, List, Optional
+from typing import List, Optional

+from compressed_tensors.config import CompressionFormat
 from compressed_tensors.quantization.quant_args import (
     DynamicType,
     QuantizationArgs,
     QuantizationStrategy,
     QuantizationType,
 )
-from pydantic import BaseModel, model_validator
+from pydantic import BaseModel, ConfigDict, model_validator


 __all__ = [
@@ -42,18 +43,21 @@ class QuantizationScheme(BaseModel):
     :param weights: quantization config for layer weights
     :param input_activations: quantization config for layer inputs
     :param output_activations: quantization config for layer outputs
+    :param format: CompressionFormat for the layer
     """

     targets: List[str]
     weights: Optional[QuantizationArgs] = None
     input_activations: Optional[QuantizationArgs] = None
     output_activations: Optional[QuantizationArgs] = None
+    format: Optional[str] = None

     @model_validator(mode="after")
     def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme":
         inputs = model.input_activations
         outputs = model.output_activations
         weights = model.weights
+        format = model.format

         if inputs is not None:
             if inputs.actorder is not None:
@@ -63,6 +67,11 @@ def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme":
             if outputs.actorder is not None:
                 raise ValueError("Cannot apply actorder to output activations")

+        if format == CompressionFormat.mixed_precision.value:
+            raise ValueError(
+                "mixed-precision cannot be set as a format for a QuantizationScheme"
+            )
+
         if (
             inputs
             and weights
@@ -81,6 +90,8 @@ def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme":

         return model

+    model_config = ConfigDict(extra="forbid")
+

 """
 Pre-Set Quantization Scheme Args
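A hedged usage sketch of the new field (import path assumed): a scheme may carry a concrete per-layer format, while mixed-precision stays reserved for the model-level config and is rejected by the validator above (pydantic wraps the ValueError in a ValidationError):

from pydantic import ValidationError
from compressed_tensors.quantization import QuantizationScheme

# Per-layer scheme carrying its own serialization format
scheme = QuantizationScheme(targets=["Linear"], format="pack-quantized")

try:
    QuantizationScheme(targets=["Linear"], format="mixed-precision")
except ValidationError:
    print("mixed-precision is rejected at the scheme level")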

src/compressed_tensors/transform/apply.py

Lines changed: 4 additions & 0 deletions

@@ -13,6 +13,7 @@
 # limitations under the License.

 import torch
+from compressed_tensors import TRANSFORM_CONFIG_NAME
 from compressed_tensors.transform import TransformConfig, TransformFactory


@@ -30,3 +31,6 @@ def apply_transform_config(model: torch.nn.Module, config: TransformConfig):
     for name, scheme in config.config_groups.items():
         factory = TransformFactory.from_scheme(scheme, name=name)
         factory.apply_to_model(model)
+
+    # attach config to model for compression/serialization
+    setattr(model, TRANSFORM_CONFIG_NAME, config)
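Because the config is now attached under the shared constant, it can be read back off the model after application; a sketch (the export of apply_transform_config from the transform package is assumed, and constructing a real TransformConfig is elided):

import torch
from compressed_tensors import TRANSFORM_CONFIG_NAME
from compressed_tensors.transform import TransformConfig, apply_transform_config

def apply_and_verify(model: torch.nn.Module, config: TransformConfig) -> None:
    # After application, the config hangs off the model under the shared
    # constant, ready for later compression/serialization to pick up.
    apply_transform_config(model, config)
    assert getattr(model, TRANSFORM_CONFIG_NAME) is config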

src/compressed_tensors/transform/factory/base.py

Lines changed: 2 additions & 2 deletions

@@ -14,11 +14,10 @@

 from abc import ABC, abstractmethod
 from collections import defaultdict
-from typing import List, Optional, Tuple, Set
+from typing import List, Optional, Set, Tuple

 import torch
 import torch.nn.utils.parametrize as P
-from compressed_tensors import InternalModule
 from compressed_tensors.registry.registry import RegistryMixin, T
 from compressed_tensors.transform import (
     TransformArgs,
@@ -34,6 +33,7 @@
     register_offload_module,
     update_offload_parameter,
 )
+from compressed_tensors.utils.internal import InternalModule
 from torch import Tensor
 from torch.nn import Module, Parameter

src/compressed_tensors/transform/factory/hadamard.py

Lines changed: 15 additions & 8 deletions

@@ -12,8 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import math
-from typing import Optional, Union
+from typing import Optional

 import torch
 from compressed_tensors.transform import TransformArgs, TransformScheme
@@ -26,7 +25,7 @@
 from compressed_tensors.utils import get_execution_device, get_offloaded_device
 from compressed_tensors.utils.helpers import ParameterizedDefaultDict
 from torch import Tensor, device, dtype
-from torch.nn import Linear, Module, Parameter
+from torch.nn import Module, Parameter


 @TransformFactory.register("hadamard")
@@ -54,14 +53,14 @@ def create_transform(self, module: Module, args: TransformArgs):
         """
         assert hasattr(module, "weight")
         size = get_transform_size(module, args.location, self.scheme.head_dim)
-        dtype = module.weight.dtype
+        dtype = self.scheme.precision
         device = get_offloaded_device(module)
         exec_device = get_execution_device(module)

         factory_kwargs = {"construct_device": exec_device}
         weight = self.weights.get(size, dtype, device, factory_kwargs=factory_kwargs)
         perm = self.perms[weight] if self.scheme.randomize else None
-        return HadamardTransform(weight, perm, args, type(module))
+        return HadamardTransform(weight, perm, self.scheme, args, type(module))

     def _create_weight(
         self,
@@ -85,15 +84,18 @@ def __init__(
         self,
         weight: Parameter,
         perm: Optional[Parameter],
+        scheme: TransformScheme,
         args: TransformArgs,
         module_type: type[torch.nn.Module],
     ):
         super().__init__()
         self.weight = weight
         self.perm = perm
+        self.scheme = scheme
         self.args = args
         self.module_type = module_type
-        self._scale = math.sqrt(weight.size(0))
+        self._scale = torch.tensor(weight.size(0), dtype=self.scheme.precision).sqrt()
+        self._precision = scheme.precision if args.is_online() else torch.float64

     def forward(self, value: Tensor) -> Tensor:
         weight = self.weight
@@ -105,6 +107,11 @@ def forward(self, value: Tensor) -> Tensor:
             weight = weight.T

         return (
-            apply_transform_weight(weight, value, self.args.location, self.module_type)
+            apply_transform_weight(
+                weight.to(self._precision),
+                value.to(self._precision),
+                self.args.location,
+                self.module_type,
+            )
             / self._scale
-        )
+        ).to(value.dtype)
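forward() now runs the rotation at a configurable precision and casts the result back to the input dtype. A self-contained sketch of that upcast/compute/downcast pattern (a standalone function and a 4x4 Hadamard stand-in, not the library's HadamardTransform):

import torch

def hadamard_apply(weight: torch.Tensor, value: torch.Tensor,
                   precision: torch.dtype = torch.float64) -> torch.Tensor:
    # Upcast both operands, apply the rotation, rescale by sqrt(dim),
    # then cast back to the caller's dtype (mirrors forward() above).
    scale = torch.tensor(weight.size(0), dtype=precision).sqrt()
    out = (value.to(precision) @ weight.to(precision)) / scale
    return out.to(value.dtype)

# 4x4 Hadamard matrix stand-in (entries +/-1, rows mutually orthogonal)
H = torch.tensor(
    [[1, 1, 1, 1],
     [1, -1, 1, -1],
     [1, 1, -1, -1],
     [1, -1, -1, 1]],
    dtype=torch.float16,
)
x = torch.randn(2, 4, dtype=torch.float16)
y = hadamard_apply(H, x)         # math runs in float64
assert y.dtype == torch.float16  # output matches the input dtype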
