1313# limitations under the License.
1414# ==============================================================================
1515import pprint
16- from enum import Enum
1716from typing import Dict , Any , Tuple , Optional
1817
19- from pydantic import BaseModel , root_validator
18+ from pydantic import BaseModel , model_validator , ConfigDict
2019
21- from mct_quantizers import QuantizationMethod
22- from model_compression_toolkit .constants import FLOAT_BITWIDTH
2320from model_compression_toolkit .logger import Logger
2421from model_compression_toolkit .target_platform_capabilities .schema .v1 import (
25- Signedness ,
26- AttributeQuantizationConfig ,
27- OpQuantizationConfig ,
2822 QuantizationConfigOptions ,
29- TargetPlatformModelComponent ,
30- OperatorsSetBase ,
3123 OperatorsSet ,
32- OperatorSetGroup ,
3324 Fusing ,
3425 OperatorSetNames )
3526
@@ -62,27 +53,26 @@ class TargetPlatformCapabilities(BaseModel):
6253
6354 SCHEMA_VERSION : int = 2
6455
65- class Config :
66- frozen = True
56+ model_config = ConfigDict (frozen = True )
6757
68- @root_validator ( allow_reuse = True )
69- def validate_after_initialization (cls , values : Dict [ str , Any ] ) -> Dict [ str , Any ] :
58+ @model_validator ( mode = "after" )
59+ def validate_after_initialization (cls , model : 'TargetPlatformCapabilities' ) -> Any :
7060 """
7161 Perform validation after the model has been instantiated.
7262
7363 Args:
74- values (Dict[str, Any] ): The instantiated target platform model.
64+ model (TargetPlatformCapabilities ): The instantiated target platform model.
7565
7666 Returns:
77- Dict[str, Any] : The validated values .
67+ TargetPlatformCapabilities : The validated model .
7868 """
7969 # Validate `default_qco`
80- default_qco = values . get ( ' default_qco' )
70+ default_qco = model . default_qco
8171 if len (default_qco .quantization_configurations ) != 1 :
8272 Logger .critical ("Default QuantizationConfigOptions must contain exactly one option." ) # pragma: no cover
8373
8474 # Validate `operator_set` uniqueness
85- operator_set = values . get ( ' operator_set' )
75+ operator_set = model . operator_set
8676 if operator_set is not None :
8777 opsets_names = [
8878 op .name .value if isinstance (op .name , OperatorSetNames ) else op .name
@@ -91,7 +81,7 @@ def validate_after_initialization(cls, values: Dict[str, Any]) -> Dict[str, Any]
9181 if len (set (opsets_names )) != len (opsets_names ):
9282 Logger .critical ("Operator Sets must have unique names." ) # pragma: no cover
9383
94- return values
84+ return model
9585
9686 def get_info (self ) -> Dict [str , Any ]:
9787 """
0 commit comments