
Commit d418aea

Error when configs are created with unrecognized fields (#386)
* forbid extra on models, add tests
* fix tests
* add dummy arg for backwards compatibility
* exclude from equality checks
* fix test

Signed-off-by: Kyle Sayers <[email protected]>
1 parent 0731aa5 commit d418aea

File tree

7 files changed: +23, -48 lines

src/compressed_tensors/quantization/quant_args.py

Lines changed: 3 additions & 1 deletion
@@ -19,7 +19,7 @@
 import torch
 from compressed_tensors.utils import Aliasable
 from compressed_tensors.utils.helpers import deprecated
-from pydantic import BaseModel, Field, field_validator, model_validator
+from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
 
 
 __all__ = [
@@ -358,6 +358,8 @@ def pytorch_dtype(self) -> torch.dtype:
     def get_observer(self) -> str:
         return self.observer
 
+    model_config = ConfigDict(extra="forbid")
+
 
 def round_to_quantized_type(
     tensor: torch.Tensor, args: QuantizationArgs
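Note: pydantic's default is extra="ignore", so a misspelled field used to be silently dropped and the model was built with defaults; with extra="forbid", construction fails fast with a ValidationError. A minimal sketch of the new behavior (num_bits is a real QuantizationArgs field, nm_bits is a deliberate typo for illustration, and the top-level import path is assumed to be re-exported by the package):

from pydantic import ValidationError
from compressed_tensors.quantization import QuantizationArgs

# recognized fields behave exactly as before
args = QuantizationArgs(num_bits=8)

# unrecognized fields now raise instead of being silently ignored
try:
    QuantizationArgs(nm_bits=8)  # typo for num_bits
except ValidationError as err:
    print(err)  # pydantic reports "Extra inputs are not permitted" for nm_bits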

src/compressed_tensors/quantization/quant_config.py

Lines changed: 7 additions & 2 deletions
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from enum import Enum
-from typing import Dict, List, Optional, Union
+from typing import Annotated, Any, Dict, List, Optional, Union
 
 from compressed_tensors.config import CompressionFormat
 from compressed_tensors.quantization.quant_args import DynamicType, QuantizationArgs
@@ -26,7 +26,7 @@
     module_type,
     parse_out_kv_cache_args,
 )
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, ConfigDict, Field
 from torch.nn import Module
 
 
@@ -142,6 +142,9 @@ class QuantizationConfig(BaseModel):
     quantization_status: QuantizationStatus = QuantizationStatus.INITIALIZED
     global_compression_ratio: Optional[float] = None
     ignore: Optional[List[str]] = Field(default_factory=list)
+    # `run_compressed` is a dummy, unused arg for backwards compatibility
+    # see: https://github.com/huggingface/transformers/pull/39324
+    run_compressed: Annotated[Any, Field(exclude=True)] = None
 
     def model_post_init(self, __context):
         """
@@ -254,3 +257,5 @@ def requires_calibration_data(self):
                 return True
 
         return False
+
+    model_config = ConfigDict(extra="forbid")
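The Annotated[Any, Field(exclude=True)] pattern lets old serialized configs that still carry run_compressed parse cleanly under extra="forbid", while the key is never written back out by model_dump(). A standalone sketch of the same pattern (the StrictConfig model and legacy_flag field below are illustrative, not part of the library):

from typing import Annotated, Any
from pydantic import BaseModel, ConfigDict, Field

class StrictConfig(BaseModel):
    model_config = ConfigDict(extra="forbid")

    version: int = 1
    # accepted for backwards compatibility, excluded from serialization
    legacy_flag: Annotated[Any, Field(exclude=True)] = None

cfg = StrictConfig(version=2, legacy_flag=True)  # parses without error
assert "legacy_flag" not in cfg.model_dump()     # never round-trips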

src/compressed_tensors/quantization/quant_scheme.py

Lines changed: 4 additions & 2 deletions
@@ -14,15 +14,15 @@
 
 import warnings
 from copy import deepcopy
-from typing import Any, Dict, List, Optional
+from typing import List, Optional
 
 from compressed_tensors.quantization.quant_args import (
     DynamicType,
     QuantizationArgs,
     QuantizationStrategy,
     QuantizationType,
 )
-from pydantic import BaseModel, model_validator
+from pydantic import BaseModel, ConfigDict, model_validator
 
 
 __all__ = [
@@ -81,6 +81,8 @@ def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme":
 
         return model
 
+    model_config = ConfigDict(extra="forbid")
+
 
 """
 Pre-Set Quantization Scheme Args

src/compressed_tensors/transform/transform_args.py

Lines changed: 3 additions & 1 deletion
@@ -15,7 +15,7 @@
 from enum import Enum
 from typing import List
 
-from pydantic import BaseModel, Field, field_validator
+from pydantic import BaseModel, ConfigDict, Field, field_validator
 
 
 __all__ = ["TransformArgs", "TransformLocation"]
@@ -74,3 +74,5 @@ def is_online(self) -> bool:
             TransformLocation.WEIGHT_INPUT,
             TransformLocation.WEIGHT_OUTPUT,
         )
+
+    model_config = ConfigDict(extra="forbid")

src/compressed_tensors/transform/transform_config.py

Lines changed: 2 additions & 40 deletions
@@ -15,7 +15,7 @@
 from typing import Dict
 
 from compressed_tensors.transform import TransformArgs, TransformScheme
-from pydantic import BaseModel
+from pydantic import BaseModel, ConfigDict
 
 
 __all__ = ["TransformConfig"]
@@ -32,42 +32,4 @@ class TransformConfig(BaseModel):
 
     config_groups: Dict[str, TransformScheme]
 
-
-# quip / quip sharp
-QUIP = TransformConfig(
-    config_groups={
-        "v": TransformScheme(
-            type="hadamard",
-            apply=[
-                TransformArgs(
-                    targets=["Linear"],
-                    location="input",  # non-mergable
-                ),
-                TransformArgs(
-                    targets=["Linear"],
-                    location="weight_input",
-                    inverse=True,
-                ),
-            ],
-            randomize=True,
-        ),
-        "u": TransformScheme(
-            type="hadamard",
-            apply=[
-                TransformArgs(
-                    targets=["Linear"],
-                    location="weight_output",
-                ),
-                TransformArgs(
-                    targets=["Linear"], location="output", inverse=True  # non-mergable
-                ),
-            ],
-            randomize=True,
-        ),
-    }
-)
-
-
-PRESET_CONFIGS = {
-    "QUIP": QUIP,
-}
+    model_config = ConfigDict(extra="forbid")

src/compressed_tensors/transform/transform_scheme.py

Lines changed: 3 additions & 1 deletion
@@ -17,7 +17,7 @@
 import torch
 from compressed_tensors.transform import TransformArgs
 from compressed_tensors.utils import TorchDtype
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, ConfigDict, Field
 
 
 __all__ = ["TransformScheme"]
@@ -46,3 +46,5 @@ class TransformScheme(BaseModel):
     requires_grad: bool = Field(default=False)
     head_dim: Optional[int] = Field(default=None)
     precision: TorchDtype = Field(default=torch.float32)
+
+    model_config = ConfigDict(extra="forbid")

tests/test_transform/factory/test_memory.py

Lines changed: 1 addition & 1 deletion
@@ -42,7 +42,7 @@ def test_memory_sharing(type, randomize, requires_grad, offload=False):
         config_groups={
             "": TransformScheme(
                 type=type,
-                randomzied=randomize,
+                randomize=randomize,
                 requires_grad=requires_grad,
                 apply=[
                     TransformArgs(targets="Linear", location="input"),
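This fixes a typo the new validation surfaced: the misspelled keyword randomzied was silently discarded under pydantic's old extra="ignore" behavior, so randomize stayed at its default and the test exercised less than intended. With extra="forbid" the old spelling now fails at construction, e.g. (a sketch; TransformScheme import as in the test file):

from pydantic import ValidationError
from compressed_tensors.transform import TransformScheme

try:
    TransformScheme(type="hadamard", randomzied=True)  # old misspelling
except ValidationError:
    print("unrecognized field caught at construction time")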
