Error when configs are created with unrecognized fields #386

Merged
7 commits merged on Aug 11, 2025
Changes from all commits
4 changes: 3 additions & 1 deletion src/compressed_tensors/quantization/quant_args.py
@@ -19,7 +19,7 @@
import torch
from compressed_tensors.utils import Aliasable
from compressed_tensors.utils.helpers import deprecated
from pydantic import BaseModel, Field, field_validator, model_validator
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator


__all__ = [
@@ -358,6 +358,8 @@ def pytorch_dtype(self) -> torch.dtype:
def get_observer(self) -> str:
return self.observer

model_config = ConfigDict(extra="forbid")


def round_to_quantized_type(
tensor: torch.Tensor, args: QuantizationArgs
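The net effect of adding model_config = ConfigDict(extra="forbid") to the pydantic models in this PR is that unrecognized fields now raise a validation error at construction time instead of being silently dropped. A minimal standalone sketch of the behavior (plain pydantic, with a stand-in field name rather than the library's actual QuantizationArgs schema):

    from pydantic import BaseModel, ConfigDict, ValidationError

    class Args(BaseModel):
        model_config = ConfigDict(extra="forbid")
        num_bits: int = 8          # illustrative field, not the full schema

    Args(num_bits=4)               # accepted
    try:
        Args(num_bitts=4)          # misspelled field name
    except ValidationError as err:
        print(err)                 # pydantic reports the extra input instead of ignoring it

This is the same class of mistake as the randomzied typo corrected in tests/test_transform/factory/test_memory.py at the bottom of this diff; with extra="forbid" it now fails loudly.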
9 changes: 7 additions & 2 deletions src/compressed_tensors/quantization/quant_config.py
@@ -13,7 +13,7 @@
# limitations under the License.

from enum import Enum
from typing import Dict, List, Optional, Union
from typing import Annotated, Any, Dict, List, Optional, Union

from compressed_tensors.config import CompressionFormat
from compressed_tensors.quantization.quant_args import DynamicType, QuantizationArgs
@@ -26,7 +26,7 @@
module_type,
parse_out_kv_cache_args,
)
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field
from torch.nn import Module


@@ -142,6 +142,9 @@ class QuantizationConfig(BaseModel):
quantization_status: QuantizationStatus = QuantizationStatus.INITIALIZED
global_compression_ratio: Optional[float] = None
ignore: Optional[List[str]] = Field(default_factory=list)
# `run_compressed` is a dummy, unused arg for backwards compatibility
# see: https://github.com/huggingface/transformers/pull/39324
run_compressed: Annotated[Any, Field(exclude=True)] = None

def model_post_init(self, __context):
"""
@@ -254,3 +257,5 @@ def requires_calibration_data(self):
return True

return False

model_config = ConfigDict(extra="forbid")
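The run_compressed field above is a compatibility shim: older callers can keep passing the argument without it leaking into serialized configs. A short sketch (simplified model, not the real QuantizationConfig) of how Annotated[Any, Field(exclude=True)] behaves in pydantic v2:

    from typing import Annotated, Any

    from pydantic import BaseModel, ConfigDict, Field

    class Config(BaseModel):
        model_config = ConfigDict(extra="forbid")
        run_compressed: Annotated[Any, Field(exclude=True)] = None

    cfg = Config(run_compressed=True)  # legacy keyword is still accepted
    print(cfg.model_dump())            # {} -- the dummy field is excluded from dumps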
6 changes: 4 additions & 2 deletions src/compressed_tensors/quantization/quant_scheme.py
@@ -14,15 +14,15 @@

import warnings
from copy import deepcopy
from typing import Any, Dict, List, Optional
from typing import List, Optional

from compressed_tensors.quantization.quant_args import (
DynamicType,
QuantizationArgs,
QuantizationStrategy,
QuantizationType,
)
from pydantic import BaseModel, model_validator
from pydantic import BaseModel, ConfigDict, model_validator


__all__ = [
@@ -81,6 +81,8 @@ def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme":

return model

model_config = ConfigDict(extra="forbid")


"""
Pre-Set Quantization Scheme Args
2 changes: 1 addition & 1 deletion src/compressed_tensors/transform/factory/base.py
@@ -14,7 +14,7 @@

from abc import ABC, abstractmethod
from collections import defaultdict
from typing import List, Optional, Tuple, Set
from typing import List, Optional, Set, Tuple

import torch
import torch.nn.utils.parametrize as P
4 changes: 3 additions & 1 deletion src/compressed_tensors/transform/transform_args.py
@@ -15,7 +15,7 @@
from enum import Enum
from typing import List

from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, ConfigDict, Field, field_validator


__all__ = ["TransformArgs", "TransformLocation"]
@@ -74,3 +74,5 @@ def is_online(self) -> bool:
TransformLocation.WEIGHT_INPUT,
TransformLocation.WEIGHT_OUTPUT,
)

model_config = ConfigDict(extra="forbid")
42 changes: 2 additions & 40 deletions src/compressed_tensors/transform/transform_config.py
@@ -15,7 +15,7 @@
from typing import Dict

from compressed_tensors.transform import TransformArgs, TransformScheme
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict


__all__ = ["TransformConfig"]
@@ -32,42 +32,4 @@ class TransformConfig(BaseModel):

config_groups: Dict[str, TransformScheme]


# quip / quip sharp
QUIP = TransformConfig(
config_groups={
"v": TransformScheme(
type="hadamard",
apply=[
TransformArgs(
targets=["Linear"],
location="input", # non-mergable
),
TransformArgs(
targets=["Linear"],
location="weight_input",
inverse=True,
),
],
randomize=True,
),
"u": TransformScheme(
type="hadamard",
apply=[
TransformArgs(
targets=["Linear"],
location="weight_output",
),
TransformArgs(
targets=["Linear"], location="output", inverse=True # non-mergable
),
],
randomize=True,
),
}
)


PRESET_CONFIGS = {
"QUIP": QUIP,
}
model_config = ConfigDict(extra="forbid")
4 changes: 3 additions & 1 deletion src/compressed_tensors/transform/transform_scheme.py
@@ -17,7 +17,7 @@
import torch
from compressed_tensors.transform import TransformArgs
from compressed_tensors.utils import TorchDtype
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field


__all__ = ["TransformScheme"]
@@ -46,3 +46,5 @@ class TransformScheme(BaseModel):
requires_grad: bool = Field(default=False)
head_dim: Optional[int] = Field(default=None)
precision: TorchDtype = Field(default=torch.float32)

model_config = ConfigDict(extra="forbid")
2 changes: 1 addition & 1 deletion tests/test_transform/factory/test_memory.py
@@ -42,7 +42,7 @@ def test_memory_sharing(type, randomize, requires_grad, offload=False):
config_groups={
"": TransformScheme(
type=type,
randomzied=randomize,
randomize=randomize,
requires_grad=requires_grad,
apply=[
TransformArgs(targets="Linear", location="input"),