diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py
index 53dbf88e..6c898429 100644
--- a/src/compressed_tensors/quantization/quant_args.py
+++ b/src/compressed_tensors/quantization/quant_args.py
@@ -19,7 +19,7 @@
 import torch
 from compressed_tensors.utils import Aliasable
 from compressed_tensors.utils.helpers import deprecated
-from pydantic import BaseModel, Field, field_validator, model_validator
+from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
 
 
 __all__ = [
@@ -358,6 +358,8 @@ def pytorch_dtype(self) -> torch.dtype:
     def get_observer(self) -> str:
         return self.observer
 
+    model_config = ConfigDict(extra="forbid")
+
 
 def round_to_quantized_type(
     tensor: torch.Tensor, args: QuantizationArgs
diff --git a/src/compressed_tensors/quantization/quant_config.py b/src/compressed_tensors/quantization/quant_config.py
index 36ed1982..db27caf8 100644
--- a/src/compressed_tensors/quantization/quant_config.py
+++ b/src/compressed_tensors/quantization/quant_config.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from enum import Enum
-from typing import Dict, List, Optional, Union
+from typing import Annotated, Any, Dict, List, Optional, Union
 
 from compressed_tensors.config import CompressionFormat
 from compressed_tensors.quantization.quant_args import DynamicType, QuantizationArgs
@@ -26,7 +26,7 @@
     module_type,
     parse_out_kv_cache_args,
 )
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, ConfigDict, Field
 from torch.nn import Module
 
 
@@ -142,6 +142,9 @@ class QuantizationConfig(BaseModel):
     quantization_status: QuantizationStatus = QuantizationStatus.INITIALIZED
     global_compression_ratio: Optional[float] = None
     ignore: Optional[List[str]] = Field(default_factory=list)
+    # `run_compressed` is a dummy, unused arg for backwards compatibility
+    # see: https://github.com/huggingface/transformers/pull/39324
+    run_compressed: Annotated[Any, Field(exclude=True)] = None
 
     def model_post_init(self, __context):
         """
@@ -254,3 +257,5 @@ def requires_calibration_data(self):
                 return True
 
         return False
+
+    model_config = ConfigDict(extra="forbid")
diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py
index 1124b55f..29864d25 100644
--- a/src/compressed_tensors/quantization/quant_scheme.py
+++ b/src/compressed_tensors/quantization/quant_scheme.py
@@ -14,7 +14,7 @@
 
 import warnings
 from copy import deepcopy
-from typing import Any, Dict, List, Optional
+from typing import List, Optional
 
 from compressed_tensors.quantization.quant_args import (
     DynamicType,
@@ -22,7 +22,7 @@
     QuantizationStrategy,
     QuantizationType,
 )
-from pydantic import BaseModel, model_validator
+from pydantic import BaseModel, ConfigDict, model_validator
 
 
 __all__ = [
@@ -81,6 +81,8 @@ def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme":
 
         return model
 
+    model_config = ConfigDict(extra="forbid")
+
 
 """
 Pre-Set Quantization Scheme Args
diff --git a/src/compressed_tensors/transform/factory/base.py b/src/compressed_tensors/transform/factory/base.py
index 2218bd30..89573d11 100644
--- a/src/compressed_tensors/transform/factory/base.py
+++ b/src/compressed_tensors/transform/factory/base.py
@@ -14,7 +14,7 @@
 
 from abc import ABC, abstractmethod
 from collections import defaultdict
-from typing import List, Optional, Tuple, Set
+from typing import List, Optional, Set, Tuple
 
 import torch
 import torch.nn.utils.parametrize as P
diff --git a/src/compressed_tensors/transform/transform_args.py b/src/compressed_tensors/transform/transform_args.py
index 8d2f8973..d3f46957 100644
--- a/src/compressed_tensors/transform/transform_args.py
+++ b/src/compressed_tensors/transform/transform_args.py
@@ -15,7 +15,7 @@
 from enum import Enum
 from typing import List
 
-from pydantic import BaseModel, Field, field_validator
+from pydantic import BaseModel, ConfigDict, Field, field_validator
 
 
 __all__ = ["TransformArgs", "TransformLocation"]
@@ -74,3 +74,5 @@ def is_online(self) -> bool:
             TransformLocation.WEIGHT_INPUT,
             TransformLocation.WEIGHT_OUTPUT,
         )
+
+    model_config = ConfigDict(extra="forbid")
diff --git a/src/compressed_tensors/transform/transform_config.py b/src/compressed_tensors/transform/transform_config.py
index df178c42..e4c6bad4 100644
--- a/src/compressed_tensors/transform/transform_config.py
+++ b/src/compressed_tensors/transform/transform_config.py
@@ -15,7 +15,7 @@
 from typing import Dict
 
 from compressed_tensors.transform import TransformArgs, TransformScheme
-from pydantic import BaseModel
+from pydantic import BaseModel, ConfigDict
 
 
 __all__ = ["TransformConfig"]
@@ -32,42 +32,4 @@ class TransformConfig(BaseModel):
 
     config_groups: Dict[str, TransformScheme]
 
-
-# quip / quip sharp
-QUIP = TransformConfig(
-    config_groups={
-        "v": TransformScheme(
-            type="hadamard",
-            apply=[
-                TransformArgs(
-                    targets=["Linear"],
-                    location="input",  # non-mergable
-                ),
-                TransformArgs(
-                    targets=["Linear"],
-                    location="weight_input",
-                    inverse=True,
-                ),
-            ],
-            randomize=True,
-        ),
-        "u": TransformScheme(
-            type="hadamard",
-            apply=[
-                TransformArgs(
-                    targets=["Linear"],
-                    location="weight_output",
-                ),
-                TransformArgs(
-                    targets=["Linear"], location="output", inverse=True  # non-mergable
-                ),
-            ],
-            randomize=True,
-        ),
-    }
-)
-
-
-PRESET_CONFIGS = {
-    "QUIP": QUIP,
-}
+    model_config = ConfigDict(extra="forbid")
diff --git a/src/compressed_tensors/transform/transform_scheme.py b/src/compressed_tensors/transform/transform_scheme.py
index 7236b9dd..2f23b17e 100644
--- a/src/compressed_tensors/transform/transform_scheme.py
+++ b/src/compressed_tensors/transform/transform_scheme.py
@@ -17,7 +17,7 @@
 import torch
 from compressed_tensors.transform import TransformArgs
 from compressed_tensors.utils import TorchDtype
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, ConfigDict, Field
 
 
 __all__ = ["TransformScheme"]
@@ -46,3 +46,5 @@ class TransformScheme(BaseModel):
     requires_grad: bool = Field(default=False)
     head_dim: Optional[int] = Field(default=None)
     precision: TorchDtype = Field(default=torch.float32)
+
+    model_config = ConfigDict(extra="forbid")
diff --git a/tests/test_transform/factory/test_memory.py b/tests/test_transform/factory/test_memory.py
index 7fc3c914..64a068c9 100644
--- a/tests/test_transform/factory/test_memory.py
+++ b/tests/test_transform/factory/test_memory.py
@@ -42,7 +42,7 @@ def test_memory_sharing(type, randomize, requires_grad, offload=False):
         config_groups={
             "": TransformScheme(
                 type=type,
-                randomzied=randomize,
+                randomize=randomize,
                 requires_grad=requires_grad,
                 apply=[
                     TransformArgs(targets="Linear", location="input"),
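
For reference, a minimal sketch of how the pattern added above behaves under pydantic v2; `ExampleConfig` and its fields are illustrative stand-ins, not classes from the library:

    from typing import Annotated, Any, List, Optional

    from pydantic import BaseModel, ConfigDict, Field, ValidationError


    class ExampleConfig(BaseModel):
        # mirrors the diff: a normal field, plus a dummy field that is accepted
        # for backwards compatibility but excluded from serialization
        ignore: Optional[List[str]] = Field(default_factory=list)
        run_compressed: Annotated[Any, Field(exclude=True)] = None

        model_config = ConfigDict(extra="forbid")


    # the excluded dummy field is accepted but dropped from the serialized config
    print(ExampleConfig(run_compressed=True).model_dump())  # {'ignore': []}

    # with extra="forbid", unknown keys now raise instead of being silently ignored
    try:
        ExampleConfig(not_a_field=1)
    except ValidationError as err:
        print(err.error_count(), "validation error")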