
Commit c991db3

forbid extra on models, add tests
Signed-off-by: Kyle Sayers <[email protected]>
1 parent b163bd9 commit c991db3

File tree

10 files changed: +27, -15 lines


src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 1 addition & 1 deletion

@@ -393,7 +393,7 @@ def compress_model(self, model: Module):
 
         if prefix in module_to_scheme or prefix in sparse_compression_targets:
             module_device = get_execution_device(module).type
-            is_meta = (module_device == "meta")
+            is_meta = module_device == "meta"
 
             exec_device = "meta" if is_meta else "cpu"
             onloading_device = "meta" if is_meta else module_device
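
For context on the surrounding logic: modules on the meta device carry shapes and dtypes but no storage, so compression must stay on "meta" rather than onloading to CPU. A minimal standalone sketch of that device selection, using plain torch rather than the library's get_execution_device helper:

    import torch
    from torch.nn import Linear

    # Meta-device modules hold metadata only; onloading them to CPU would
    # allocate real memory, so both devices stay "meta" in that case.
    module = Linear(8, 8, device="meta")
    module_device = next(module.parameters()).device.type
    is_meta = module_device == "meta"

    exec_device = "meta" if is_meta else "cpu"
    onloading_device = "meta" if is_meta else module_device
    print(exec_device, onloading_device)  # meta meta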

src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py

Lines changed: 6 additions & 2 deletions

@@ -178,9 +178,13 @@ def sparse24_bitmask_compress(
 
     if tensor.is_meta:
         num_rows, num_cols = tensor.shape
-        compressed_values = torch.empty((num_rows, num_cols // 2), dtype=tensor.dtype, device="meta")
+        compressed_values = torch.empty(
+            (num_rows, num_cols // 2), dtype=tensor.dtype, device="meta"
+        )
         packed_cols = (num_cols + 7) // 8
-        bitmasks_packed = torch.empty((num_rows, packed_cols), dtype=torch.uint8, device="meta")
+        bitmasks_packed = torch.empty(
+            (num_rows, packed_cols), dtype=torch.uint8, device="meta"
+        )
         return compressed_values, bitmasks_packed
 
     bytemasks = get_24_bytemasks(tensor=tensor)
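
The meta branch above works because the output shapes are computable without any data: 2:4 sparsity keeps exactly half the values per row, and the bitmask packs 8 boolean positions per uint8 byte. A rough usage sketch (toy shapes, not library code):

    import torch

    # Shapes for a 2:4-sparse bitmask compression of a (128, 512) weight.
    tensor = torch.empty((128, 512), dtype=torch.bfloat16, device="meta")
    num_rows, num_cols = tensor.shape

    compressed_values = torch.empty(
        (num_rows, num_cols // 2), dtype=tensor.dtype, device="meta"
    )
    packed_cols = (num_cols + 7) // 8  # ceil(num_cols / 8) mask bytes per row
    bitmasks_packed = torch.empty(
        (num_rows, packed_cols), dtype=torch.uint8, device="meta"
    )
    print(tuple(compressed_values.shape), tuple(bitmasks_packed.shape))
    # (128, 256) (128, 64)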

src/compressed_tensors/quantization/lifecycle/initialize.py

Lines changed: 6 additions & 1 deletion

@@ -189,7 +189,12 @@ def _initialize_scale_zero_point(
     else:
         # TODO: consider erroring out in the future as if the dtype if not one of these,
         # there is likely bug
-        if scale_dtype not in [torch.float16, torch.bfloat16, torch.float32, torch.float64]:
+        if scale_dtype not in [
+            torch.float16,
+            torch.bfloat16,
+            torch.float32,
+            torch.float64,
+        ]:
            scale_dtype = torch.float16
    zp_dtype = quantization_args.pytorch_dtype()
 
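A standalone sketch of the fallback being reformatted here (hypothetical helper name; the library does this inline): scale dtypes must be one of the standard float types, and anything else falls back to float16:

    import torch

    def pick_scale_dtype(scale_dtype: torch.dtype) -> torch.dtype:
        # Non-float (or unsupported float) dtypes fall back to float16.
        if scale_dtype not in [
            torch.float16,
            torch.bfloat16,
            torch.float32,
            torch.float64,
        ]:
            return torch.float16
        return scale_dtype

    print(pick_scale_dtype(torch.int8))      # torch.float16 (fallback)
    print(pick_scale_dtype(torch.bfloat16))  # torch.bfloat16 (kept)
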
src/compressed_tensors/quantization/quant_args.py

Lines changed: 2 additions & 1 deletion

@@ -19,7 +19,7 @@
 import torch
 from compressed_tensors.utils import Aliasable
 from compressed_tensors.utils.helpers import deprecated
-from pydantic import BaseModel, Field, field_validator, model_validator
+from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
 
 
 __all__ = [

@@ -186,6 +186,7 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
             "Observers constructor excluding quantization range or symmetry"
         ),
     )
+    model_config = ConfigDict(extra="forbid")
 
     @field_validator("type", mode="before")
     def validate_type(cls, value) -> QuantizationType:
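
This model_config line is the heart of the commit. A minimal sketch of the behavior change on a toy model (not QuantizationArgs itself): with extra="forbid", pydantic rejects unknown keys instead of silently ignoring them, so a misspelled field in a serialized config fails at parse time:

    from pydantic import BaseModel, ConfigDict, ValidationError

    class Args(BaseModel):
        num_bits: int = 8
        model_config = ConfigDict(extra="forbid")

    Args(num_bits=4)  # ok

    try:
        Args(num_bitss=4)  # typo in the field name
    except ValidationError as err:
        print(err)  # "Extra inputs are not permitted"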

src/compressed_tensors/quantization/quant_config.py

Lines changed: 2 additions & 1 deletion

@@ -26,7 +26,7 @@
     module_type,
     parse_out_kv_cache_args,
 )
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, ConfigDict, Field
 from torch.nn import Module
 
 

@@ -142,6 +142,7 @@ class QuantizationConfig(BaseModel):
     quantization_status: QuantizationStatus = QuantizationStatus.INITIALIZED
     global_compression_ratio: Optional[float] = None
    ignore: Optional[List[str]] = Field(default_factory=list)
+    model_config = ConfigDict(extra="forbid")
 
     def model_post_init(self, __context):
         """

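Note that model_config is declared per class rather than once, which is why this commit touches every model: pydantic does not propagate extra="forbid" into nested models of other types, so each one must forbid extras itself. A toy illustration (hypothetical Inner/Outer names):

    from pydantic import BaseModel, ConfigDict, ValidationError

    class Inner(BaseModel):
        x: int = 0
        model_config = ConfigDict(extra="forbid")

    class Outer(BaseModel):
        inner: Inner = Inner()
        model_config = ConfigDict(extra="forbid")

    Outer.model_validate({"inner": {"x": 1}})  # ok
    try:
        Outer.model_validate({"inner": {"x": 1, "y": 2}})  # extra key in Inner
    except ValidationError as err:
        print(err)
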
src/compressed_tensors/quantization/quant_scheme.py

Lines changed: 3 additions & 2 deletions

@@ -13,15 +13,15 @@
 # limitations under the License.
 
 from copy import deepcopy
-from typing import Any, Dict, List, Optional
+from typing import List, Optional
 
 from compressed_tensors.quantization.quant_args import (
     DynamicType,
     QuantizationArgs,
     QuantizationStrategy,
     QuantizationType,
 )
-from pydantic import BaseModel, model_validator
+from pydantic import BaseModel, ConfigDict, model_validator
 
 
 __all__ = [

@@ -47,6 +47,7 @@ class QuantizationScheme(BaseModel):
     weights: Optional[QuantizationArgs] = None
     input_activations: Optional[QuantizationArgs] = None
     output_activations: Optional[QuantizationArgs] = None
+    model_config = ConfigDict(extra="forbid")
 
     @model_validator(mode="after")
     def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme":

src/compressed_tensors/transform/transform_args.py

Lines changed: 2 additions & 1 deletion

@@ -15,7 +15,7 @@
 from enum import Enum
 from typing import List
 
-from pydantic import BaseModel, Field, field_validator
+from pydantic import BaseModel, ConfigDict, Field, field_validator
 
 
 __all__ = ["TransformArgs", "TransformLocation"]

@@ -61,6 +61,7 @@ class TransformArgs(BaseModel):
     location: TransformLocation
     inverse: bool = Field(default=False)
     ignore: List[str] = Field(default_factory=list)
+    model_config = ConfigDict(extra="forbid")
 
     @field_validator("targets", "ignore", mode="before")
     @classmethod

src/compressed_tensors/transform/transform_config.py

Lines changed: 2 additions & 1 deletion

@@ -15,7 +15,7 @@
 from typing import Dict
 
 from compressed_tensors.transform import TransformArgs, TransformScheme
-from pydantic import BaseModel
+from pydantic import BaseModel, ConfigDict
 
 
 __all__ = ["TransformConfig"]

@@ -31,6 +31,7 @@ class TransformConfig(BaseModel):
     """
 
     config_groups: Dict[str, TransformScheme]
+    model_config = ConfigDict(extra="forbid")
 
 
 # quip / quip sharp

src/compressed_tensors/transform/transform_scheme.py

Lines changed: 2 additions & 1 deletion

@@ -15,7 +15,7 @@
 from typing import List
 
 from compressed_tensors.transform import TransformArgs
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, ConfigDict, Field
 
 
 __all__ = ["TransformScheme"]

@@ -40,3 +40,4 @@ class TransformScheme(BaseModel):
     apply: List[TransformArgs] = Field(default_factory=list)
     randomize: bool = Field(default=False)
     requires_grad: bool = Field(default=False)
+    model_config = ConfigDict(extra="forbid")

tests/test_compressors/model_compressors/test_model_compressor.py

Lines changed: 1 addition & 4 deletions

@@ -446,10 +446,7 @@ def test_compress_model_meta(model_stub, q_format, s_config):
         cpu_model, s_config, q_format
     )
     # Only stores dtype because meta model does not store values
-    expected = {
-        k: v.dtype
-        for k, v in reference_compressor.compress(cpu_model).items()
-    }
+    expected = {k: v.dtype for k, v in reference_compressor.compress(cpu_model).items()}
 
     # Load model on meta device
     meta_model = AutoModelForCausalLM.from_pretrained(
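
The comparison pattern in this test is worth noting: meta tensors hold no values, so the CPU reference output is reduced to dtypes before comparing. A self-contained sketch of that pattern (toy tensors and key names, not the real test fixtures):

    import torch

    # CPU reference output, reduced to dtypes only.
    reference = {"weight_packed": torch.zeros(4, 2, dtype=torch.uint8)}
    expected = {k: v.dtype for k, v in reference.items()}

    # Meta-device output: same keys and dtypes, but no stored values.
    meta_out = {"weight_packed": torch.empty(4, 2, dtype=torch.uint8, device="meta")}
    for key, dtype in expected.items():
        assert meta_out[key].dtype == dtype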
