Commit 1fa514a

replace deprecated pydantic functions (#221)
1 parent 5771dec commit 1fa514a

File tree: 6 files changed (+15 / -8 lines)

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 3 additions & 3 deletions

@@ -137,7 +137,7 @@ def from_compression_config(
                format, **sparsity_config
            )
        if quantization_config is not None:
-            quantization_config = QuantizationConfig.parse_obj(quantization_config)
+            quantization_config = QuantizationConfig.model_validate(quantization_config)

        return cls(
            sparsity_config=sparsity_config, quantization_config=quantization_config
@@ -193,7 +193,7 @@ def parse_sparsity_config(

        if is_compressed_tensors_config(compression_config):
            s_config = compression_config.sparsity_config
-            return s_config.dict() if s_config is not None else None
+            return s_config.model_dump() if s_config is not None else None

        return compression_config.get(SPARSITY_CONFIG_NAME, None)

@@ -214,7 +214,7 @@ def parse_quantization_config(

        if is_compressed_tensors_config(compression_config):
            q_config = compression_config.quantization_config
-            return q_config.dict() if q_config is not None else None
+            return q_config.model_dump() if q_config is not None else None

        quantization_config = deepcopy(compression_config)
        quantization_config.pop(SPARSITY_CONFIG_NAME, None)
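
Every change in this commit applies the same pydantic v1 -> v2 migration: QuantizationConfig.parse_obj(...) becomes QuantizationConfig.model_validate(...), and config.dict() becomes config.model_dump(). A minimal sketch of the equivalence, using a toy model that is illustrative only (not part of this repo) and assuming pydantic >= 2 is installed:

# Illustrative only: ToyConfig stands in for any pydantic BaseModel,
# e.g. the config classes touched in this commit.
from pydantic import BaseModel


class ToyConfig(BaseModel):
    name: str
    num_bits: int = 8


data = {"name": "w8a8", "num_bits": 8}

# pydantic v1 spellings (deprecated in v2): ToyConfig.parse_obj(data), cfg.dict()
cfg = ToyConfig.model_validate(data)  # v2: build and validate from a dict
dumped = cfg.model_dump()             # v2: serialize back to a plain dict
assert ToyConfig.model_validate(dumped) == cfg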

src/compressed_tensors/quantization/quant_config.py

Lines changed: 1 addition & 1 deletion

@@ -160,7 +160,7 @@ def model_post_init(self, __context):

    def to_dict(self):
        # for compatibility with HFQuantizer
-        return self.dict()
+        return self.model_dump()

    @staticmethod
    def from_pretrained(

tests/test_compressors/model_compressors/test_model_compressor.py

Lines changed: 2 additions & 2 deletions

@@ -99,8 +99,8 @@ def test_hf_compressor_tensors_config(s_config, q_config, tmp_path):
    )
    q_config = QuantizationConfig(**q_config) if q_config is not None else None

-    s_config_dict = s_config.dict() if s_config is not None else None
-    q_config_dict = q_config.dict() if q_config is not None else None
+    s_config_dict = s_config.model_dump() if s_config is not None else None
+    q_config_dict = q_config.model_dump() if q_config is not None else None

    assert compressor.sparsity_config == s_config
    assert compressor.quantization_config == q_config

tests/test_quantization/lifecycle/test_apply.py

Lines changed: 1 addition & 1 deletion

@@ -222,7 +222,7 @@ def get_sample_tinyllama_quant_config(status: str = "frozen"):
        },
        "ignore": ["LlamaRotaryEmbedding", "model.layers.1.mlp.down_proj"],
    }
-    return QuantizationConfig.parse_obj(config_dict)
+    return QuantizationConfig.model_validate(config_dict)


@requires_accelerate()

tests/test_quantization/lifecycle/test_dynamic_lifecycle.py

Lines changed: 1 addition & 1 deletion

@@ -110,4 +110,4 @@ def get_sample_dynamic_tinyllama_quant_config():
        },
        "ignore": ["LlamaRotaryEmbedding", "model.layers.1.mlp.down_proj"],
    }
-    return QuantizationConfig.parse_obj(config_dict)
+    return QuantizationConfig.model_validate(config_dict)

tests/test_quantization/test_quant_config.py

Lines changed: 7 additions & 0 deletions

@@ -72,3 +72,10 @@ def test_load_scheme_from_preset(scheme_name: str):
    assert scheme_name in config.config_groups
    assert isinstance(config.config_groups[scheme_name], QuantizationScheme)
    assert config.config_groups[scheme_name].targets == targets
+
+
+def test_to_dict():
+    config_groups = {"group_1": QuantizationScheme(targets=[])}
+    config = QuantizationConfig(config_groups=config_groups)
+    reloaded = QuantizationConfig.model_validate(config.to_dict())
+    assert config == reloaded
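
The new test exercises the round trip that to_dict() (now backed by model_dump()) is meant to support for HFQuantizer compatibility: the dumped dict can be fed straight back into QuantizationConfig.model_validate. A hedged sketch of the same round trip through a JSON file, assuming the import path below matches the repo layout and that all field values remain JSON-serializable; the file name is illustrative:

# Sketch only: persist a QuantizationConfig to JSON and reload it.
import json

from compressed_tensors.quantization import QuantizationConfig, QuantizationScheme

config = QuantizationConfig(
    config_groups={"group_1": QuantizationScheme(targets=["Linear"])}
)

# Persist via to_dict(), which delegates to model_dump() after this commit.
with open("quant_config.json", "w") as f:
    json.dump(config.to_dict(), f)

# Reload and re-validate with the pydantic v2 API.
with open("quant_config.json") as f:
    reloaded = QuantizationConfig.model_validate(json.load(f))

assert reloaded == config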
