Skip to content

Commit 4ef531a

Browse files
author
liord
committed
update schema v2 with pydantic v2
1 parent 07c18a3 commit 4ef531a

File tree

1 file changed

+9
-19
lines changed
  • model_compression_toolkit/target_platform_capabilities/schema

1 file changed

+9
-19
lines changed

model_compression_toolkit/target_platform_capabilities/schema/v2.py

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -13,23 +13,14 @@
1313
# limitations under the License.
1414
# ==============================================================================
1515
import pprint
16-
from enum import Enum
1716
from typing import Dict, Any, Tuple, Optional
1817

19-
from pydantic import BaseModel, root_validator
18+
from pydantic import BaseModel, model_validator, ConfigDict
2019

21-
from mct_quantizers import QuantizationMethod
22-
from model_compression_toolkit.constants import FLOAT_BITWIDTH
2320
from model_compression_toolkit.logger import Logger
2421
from model_compression_toolkit.target_platform_capabilities.schema.v1 import (
25-
Signedness,
26-
AttributeQuantizationConfig,
27-
OpQuantizationConfig,
2822
QuantizationConfigOptions,
29-
TargetPlatformModelComponent,
30-
OperatorsSetBase,
3123
OperatorsSet,
32-
OperatorSetGroup,
3324
Fusing,
3425
OperatorSetNames)
3526

@@ -62,27 +53,26 @@ class TargetPlatformCapabilities(BaseModel):
6253

6354
SCHEMA_VERSION: int = 2
6455

65-
class Config:
66-
frozen = True
56+
model_config = ConfigDict(frozen=True)
6757

68-
@root_validator(allow_reuse=True)
69-
def validate_after_initialization(cls, values: Dict[str, Any]) -> Dict[str, Any]:
58+
@model_validator(mode="after")
59+
def validate_after_initialization(cls, model: 'TargetPlatformCapabilities') -> Any:
7060
"""
7161
Perform validation after the model has been instantiated.
7262
7363
Args:
74-
values (Dict[str, Any]): The instantiated target platform model.
64+
model (TargetPlatformCapabilities): The instantiated target platform model.
7565
7666
Returns:
77-
Dict[str, Any]: The validated values.
67+
TargetPlatformCapabilities: The validated model.
7868
"""
7969
# Validate `default_qco`
80-
default_qco = values.get('default_qco')
70+
default_qco = model.default_qco
8171
if len(default_qco.quantization_configurations) != 1:
8272
Logger.critical("Default QuantizationConfigOptions must contain exactly one option.") # pragma: no cover
8373

8474
# Validate `operator_set` uniqueness
85-
operator_set = values.get('operator_set')
75+
operator_set = model.operator_set
8676
if operator_set is not None:
8777
opsets_names = [
8878
op.name.value if isinstance(op.name, OperatorSetNames) else op.name
@@ -91,7 +81,7 @@ def validate_after_initialization(cls, values: Dict[str, Any]) -> Dict[str, Any]
9181
if len(set(opsets_names)) != len(opsets_names):
9282
Logger.critical("Operator Sets must have unique names.") # pragma: no cover
9383

94-
return values
84+
return model
9585

9686
def get_info(self) -> Dict[str, Any]:
9787
"""

0 commit comments

Comments
 (0)