diff --git a/.github/workflows/run_tests_python310_keras212.yml b/.github/workflows/run_tests_python310_keras212.yml deleted file mode 100644 index 172faaa31..000000000 --- a/.github/workflows/run_tests_python310_keras212.yml +++ /dev/null @@ -1,19 +0,0 @@ -name: Python 3.10, Keras 2.12 -on: - workflow_dispatch: # Allow manual triggers - schedule: - - cron: 0 0 * * * - pull_request: - branches: - - main - -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} - -jobs: - run-tests: - uses: ./.github/workflows/run_keras_tests.yml - with: - python-version: "3.10" - tf-version: "2.12.*" \ No newline at end of file diff --git a/.github/workflows/run_tests_python310_keras213.yml b/.github/workflows/run_tests_python310_keras213.yml deleted file mode 100644 index 7a90da443..000000000 --- a/.github/workflows/run_tests_python310_keras213.yml +++ /dev/null @@ -1,19 +0,0 @@ -name: Python 3.10, Keras 2.13 -on: - workflow_dispatch: # Allow manual triggers - schedule: - - cron: 0 0 * * * - pull_request: - branches: - - main - -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} - -jobs: - run-tests: - uses: ./.github/workflows/run_keras_tests.yml - with: - python-version: "3.10" - tf-version: "2.13.*" \ No newline at end of file diff --git a/.github/workflows/run_tests_python311_keras212.yml b/.github/workflows/run_tests_python311_keras212.yml deleted file mode 100644 index b129a5843..000000000 --- a/.github/workflows/run_tests_python311_keras212.yml +++ /dev/null @@ -1,19 +0,0 @@ -name: Python 3.11, Keras 2.12 -on: - workflow_dispatch: # Allow manual triggers - schedule: - - cron: 0 0 * * * - pull_request: - branches: - - main - -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} - -jobs: - run-tests: - uses: ./.github/workflows/run_keras_tests.yml - with: - 
python-version: "3.11" - tf-version: "2.12.*" \ No newline at end of file diff --git a/.github/workflows/run_tests_python311_keras213.yml b/.github/workflows/run_tests_python311_keras213.yml deleted file mode 100644 index 635534bfc..000000000 --- a/.github/workflows/run_tests_python311_keras213.yml +++ /dev/null @@ -1,19 +0,0 @@ -name: Python 3.11, Keras 2.13 -on: - workflow_dispatch: # Allow manual triggers - schedule: - - cron: 0 0 * * * - pull_request: - branches: - - main - -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} - -jobs: - run-tests: - uses: ./.github/workflows/run_keras_tests.yml - with: - python-version: "3.11" - tf-version: "2.13.*" \ No newline at end of file diff --git a/.github/workflows/run_tests_python39_keras212.yml b/.github/workflows/run_tests_python39_keras212.yml deleted file mode 100644 index 5f42dd35b..000000000 --- a/.github/workflows/run_tests_python39_keras212.yml +++ /dev/null @@ -1,19 +0,0 @@ -name: Python 3.9, Keras 2.12 -on: - workflow_dispatch: # Allow manual triggers - schedule: - - cron: 0 0 * * * - pull_request: - branches: - - main - -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} - -jobs: - run-tests: - uses: ./.github/workflows/run_keras_tests.yml - with: - python-version: "3.9" - tf-version: "2.12.*" \ No newline at end of file diff --git a/.github/workflows/run_tests_python39_keras213.yml b/.github/workflows/run_tests_python39_keras213.yml deleted file mode 100644 index fa7c34390..000000000 --- a/.github/workflows/run_tests_python39_keras213.yml +++ /dev/null @@ -1,19 +0,0 @@ -name: Python 3.9, Keras 2.13 -on: - workflow_dispatch: # Allow manual triggers - schedule: - - cron: 0 0 * * * - pull_request: - branches: - - main - -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} - -jobs: - run-tests: - 
uses: ./.github/workflows/run_keras_tests.yml - with: - python-version: "3.9" - tf-version: "2.13.*" \ No newline at end of file diff --git a/.github/workflows/run_tests_suite_coverage.yml b/.github/workflows/run_tests_suite_coverage.yml index d8b4947bf..d916c0da9 100644 --- a/.github/workflows/run_tests_suite_coverage.yml +++ b/.github/workflows/run_tests_suite_coverage.yml @@ -65,7 +65,7 @@ jobs: python -m venv tf_env source tf_env/bin/activate python -m pip install --upgrade pip - pip install -r requirements.txt tensorflow==2.13.* coverage pytest pytest-mock + pip install -r requirements.txt tensorflow==2.15.* coverage pytest pytest-mock - name: Run TensorFlow tests (unittest) run: | diff --git a/README.md b/README.md index 40cbd3ae5..5caad0152 100644 --- a/README.md +++ b/README.md @@ -30,7 +30,7 @@ ________________________________________________________________________________ ##
Getting Started
### Quick Installation -Pip install the model compression toolkit package in a Python>=3.9 environment with PyTorch>=2.1 or Tensorflow>=2.12. +Pip install the model compression toolkit package in a Python>=3.9 environment with PyTorch>=2.1 or Tensorflow>=2.14. ``` pip install model-compression-toolkit ``` @@ -137,11 +137,11 @@ Currently, MCT is being tested on various Python, Pytorch and TensorFlow version | Python 3.11 | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch22.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch22.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch23.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch23.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch24.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch24.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch25.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch25.yml) | | Python 3.12 | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch22.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch22.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch23.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch23.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch24.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch24.yml) | [![Run 
Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch25.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch25.yml) | -| | TensorFlow 2.12 | TensorFlow 2.13 | TensorFlow 2.14 | TensorFlow 2.15 | +| | TensorFlow 2.14 | TensorFlow 2.15 | |-------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| Python 3.9 | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras212.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras212.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras213.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras213.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras214.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras214.yml) | [![Run 
Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras215.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras215.yml) | -| Python 3.10 | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras212.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras212.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras213.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras213.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras214.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras214.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras215.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras215.yml) | -| Python 3.11 | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras212.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras212.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras213.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras213.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras214.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras214.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras215.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras215.yml) | +| Python 3.9 | 
[![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras214.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras214.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras215.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras215.yml) | +| Python 3.10 | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras214.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras214.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras215.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras215.yml) | +| Python 3.11 | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras214.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras214.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras215.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras215.yml) | diff --git a/docsrc/source/conf.py b/docsrc/source/conf.py index c375580de..e2eff2bf4 100644 --- a/docsrc/source/conf.py +++ b/docsrc/source/conf.py @@ -10,6 +10,20 @@ # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. # +import logging +import re + +# Define a filter that suppresses log records matching our unwanted warning message. +class IgnoreGuardedTypeImportFilter(logging.Filter): + def filter(self, record): + # Return False (i.e. ignore) if the message contains our specific text. 
+ if re.search(r"Failed guarded type import with ImportError.*AbstractSetIntStr", record.getMessage()): + return False + return True + +# Attach the filter to the "sphinx" logger. +logger = logging.getLogger("sphinx") +logger.addFilter(IgnoreGuardedTypeImportFilter()) import os import sys diff --git a/model_compression_toolkit/target_platform_capabilities/schema/v1.py b/model_compression_toolkit/target_platform_capabilities/schema/v1.py index fbbf5a2fb..284b56cd2 100644 --- a/model_compression_toolkit/target_platform_capabilities/schema/v1.py +++ b/model_compression_toolkit/target_platform_capabilities/schema/v1.py @@ -16,7 +16,8 @@ from enum import Enum from typing import Dict, Any, Union, Tuple, List, Optional, Literal, Annotated -from pydantic import BaseModel, Field, root_validator, validator, PositiveInt +from pydantic import BaseModel, Field, root_validator, validator, PositiveInt, ConfigDict, field_validator, \ + model_validator from mct_quantizers import QuantizationMethod from model_compression_toolkit.constants import FLOAT_BITWIDTH @@ -118,9 +119,7 @@ class AttributeQuantizationConfig(BaseModel): enable_weights_quantization: bool = False lut_values_bitwidth: Optional[int] = None - class Config: - # Makes the model immutable (frozen) - frozen = True + model_config = ConfigDict(frozen=True) @property def field_names(self) -> list: @@ -137,7 +136,7 @@ def clone_and_edit(self, **kwargs) -> 'AttributeQuantizationConfig': Returns: AttributeQuantizationConfig: A new instance of AttributeQuantizationConfig with updated attributes. 
""" - return self.copy(update=kwargs) + return self.model_copy(update=kwargs) class OpQuantizationConfig(BaseModel): @@ -164,15 +163,14 @@ class OpQuantizationConfig(BaseModel): supported_input_activation_n_bits: Union[int, Tuple[int, ...]] enable_activation_quantization: bool quantization_preserving: bool - fixed_scale: Optional[float] - fixed_zero_point: Optional[int] - simd_size: Optional[int] signedness: Signedness + fixed_scale: Optional[float] = None + fixed_zero_point: Optional[int] = None + simd_size: Optional[int] = None - class Config: - frozen = True + model_config = ConfigDict(frozen=True) - @validator('supported_input_activation_n_bits', pre=True, allow_reuse=True) + @field_validator('supported_input_activation_n_bits', mode='before') def validate_supported_input_activation_n_bits(cls, v): """ Validate and process the supported_input_activation_n_bits field. @@ -199,9 +197,9 @@ def get_info(self) -> Dict[str, Any]: return self.dict() # pragma: no cover def clone_and_edit( - self, - attr_to_edit: Dict[str, Dict[str, Any]] = {}, - **kwargs: Any + self, + attr_to_edit: Dict[str, Dict[str, Any]] = {}, + **kwargs: Any ) -> 'OpQuantizationConfig': """ Clone the quantization config and edit some of its attributes. @@ -215,17 +213,17 @@ def clone_and_edit( OpQuantizationConfig: Edited quantization configuration. 
""" # Clone and update top-level attributes - updated_config = self.copy(update=kwargs) + updated_config = self.model_copy(update=kwargs) # Clone and update nested immutable dataclasses in `attr_weights_configs_mapping` updated_attr_mapping = { attr_name: (attr_cfg.clone_and_edit(**attr_to_edit[attr_name]) - if attr_name in attr_to_edit else attr_cfg) + if attr_name in attr_to_edit else attr_cfg) for attr_name, attr_cfg in updated_config.attr_weights_configs_mapping.items() } # Return a new instance with the updated attribute mapping - return updated_config.copy(update={'attr_weights_configs_mapping': updated_attr_mapping}) + return updated_config.model_copy(update={'attr_weights_configs_mapping': updated_attr_mapping}) class QuantizationConfigOptions(BaseModel): @@ -239,10 +237,9 @@ class QuantizationConfigOptions(BaseModel): quantization_configurations: Tuple[OpQuantizationConfig, ...] base_config: Optional[OpQuantizationConfig] = None - class Config: - frozen = True + model_config = ConfigDict(frozen=True) - @root_validator(pre=True, allow_reuse=True) + @model_validator(mode="before") def validate_and_set_base_config(cls, values: Dict[str, Any]) -> Dict[str, Any]: """ Validate and set the base_config based on quantization_configurations. @@ -282,12 +279,6 @@ def validate_and_set_base_config(cls, values: Dict[str, Any]) -> Dict[str, Any]: "'base_config' must be included in the quantization config options." ) # pragma: no cover - # if num_configs == 1: - # if base_config != quantization_configurations[0]: - # Logger.critical( - # "'base_config' should be the same as the sole item in 'quantization_configurations'." - # ) # pragma: no cover - values['base_config'] = base_config # When loading from JSON, lists are returned. If the value is a list, convert it to a tuple. 
@@ -312,7 +303,7 @@ def clone_and_edit(self, **kwargs) -> 'QuantizationConfigOptions': # Clone and update all configurations updated_configs = tuple(cfg.clone_and_edit(**kwargs) for cfg in self.quantization_configurations) - return self.copy(update={ + return self.model_copy(update={ 'base_config': updated_base_config, 'quantization_configurations': updated_configs }) @@ -360,7 +351,7 @@ def clone_and_edit_weight_attribute( updated_cfg = qc.clone_and_edit(attr_weights_configs_mapping=updated_attr_mapping) updated_configs.append(updated_cfg) - return self.copy(update={ + return self.model_copy(update={ 'base_config': updated_base_config, 'quantization_configurations': tuple(updated_configs) }) @@ -398,7 +389,7 @@ def clone_and_map_weights_attr_keys( updated_cfg = qc.clone_and_edit(attr_weights_configs_mapping=new_attr_mapping) updated_configs.append(updated_cfg) - return self.copy(update={ + return self.model_copy(update={ 'base_config': new_base_config, 'quantization_configurations': tuple(updated_configs) }) @@ -412,12 +403,12 @@ def get_info(self) -> Dict[str, Any]: """ return {f'option_{i}': cfg.get_info() for i, cfg in enumerate(self.quantization_configurations)} + class TargetPlatformModelComponent(BaseModel): """ Component of TargetPlatformCapabilities (Fusing, OperatorsSet, etc.). 
""" - class Config: - frozen = True + model_config = ConfigDict(frozen=True) class OperatorsSetBase(TargetPlatformModelComponent): @@ -444,8 +435,7 @@ class OperatorsSet(OperatorsSetBase): # Define a private attribute _type type: Literal["OperatorsSet"] = "OperatorsSet" - class Config: - frozen = True + model_config = ConfigDict(frozen=True) def get_info(self) -> Dict[str, Any]: """ @@ -471,10 +461,9 @@ class OperatorSetGroup(OperatorsSetBase): # Define a private attribute _type type: Literal["OperatorSetGroup"] = "OperatorSetGroup" - class Config: - frozen = True + model_config = ConfigDict(frozen=True) - @root_validator(pre=True, allow_reuse=True) + @model_validator(mode="before") def validate_and_set_name(cls, values: Dict[str, Any]) -> Dict[str, Any]: """ Validate the input and set the concatenated name based on the operators_set. @@ -512,6 +501,7 @@ def get_info(self) -> Dict[str, Any]: "operators_set": [op.get_info() for op in self.operators_set] } + class Fusing(TargetPlatformModelComponent): """ Fusing defines a tuple of operators that should be combined and treated as a single operator, @@ -525,10 +515,9 @@ class Fusing(TargetPlatformModelComponent): operator_groups: Tuple[Annotated[Union[OperatorsSet, OperatorSetGroup], Field(discriminator='type')], ...] name: Optional[str] = None # Will be set in the validator if not given. - class Config: - frozen = True + model_config = ConfigDict(frozen=True) - @root_validator(pre=True, allow_reuse=True) + @model_validator(mode="before") def validate_and_set_name(cls, values: Dict[str, Any]) -> Dict[str, Any]: """ Validate the operator_groups and set the name by concatenating operator group names. 
@@ -555,24 +544,15 @@ def validate_and_set_name(cls, values: Dict[str, Any]) -> Dict[str, Any]: return values - @root_validator(allow_reuse=True) - def validate_after_initialization(cls, values: Dict[str, Any]) -> Dict[str, Any]: + @model_validator(mode="after") + def validate_after_initialization(cls, model: 'Fusing') -> Any: """ Perform validation after the model has been instantiated. - - Args: - values (Dict[str, Any]): The instantiated fusing. - - Returns: - Dict[str, Any]: The validated values. + Ensures that there are at least two operator groups. """ - operator_groups = values.get('operator_groups') - - # Validate that there are at least two operator groups - if len(operator_groups) < 2: + if len(model.operator_groups) < 2: Logger.critical("Fusing cannot be created for a single operator.") # pragma: no cover - - return values + return model def contains(self, other: Any) -> bool: """ @@ -621,6 +601,7 @@ def get_info(self) -> Union[Dict[str, str], str]: for x in self.operator_groups ]) + class TargetPlatformCapabilities(BaseModel): """ Represents the hardware configuration used for quantized model inference. @@ -638,38 +619,37 @@ class TargetPlatformCapabilities(BaseModel): SCHEMA_VERSION (int): Version of the schema for the Target Platform Model. 
""" default_qco: QuantizationConfigOptions - operator_set: Optional[Tuple[OperatorsSet, ...]] - fusing_patterns: Optional[Tuple[Fusing, ...]] - tpc_minor_version: Optional[int] - tpc_patch_version: Optional[int] - tpc_platform_type: Optional[str] + operator_set: Optional[Tuple[OperatorsSet, ...]] = None + fusing_patterns: Optional[Tuple[Fusing, ...]] = None + tpc_minor_version: Optional[int] = None + tpc_patch_version: Optional[int] = None + tpc_platform_type: Optional[str] = None add_metadata: bool = True name: Optional[str] = "default_tpc" is_simd_padding: bool = False SCHEMA_VERSION: int = 1 - class Config: - frozen = True + model_config = ConfigDict(frozen=True) - @root_validator(allow_reuse=True) - def validate_after_initialization(cls, values: Dict[str, Any]) -> Dict[str, Any]: + @model_validator(mode="after") + def validate_after_initialization(cls, model: 'TargetPlatformCapabilities') -> Any: """ Perform validation after the model has been instantiated. Args: - values (Dict[str, Any]): The instantiated target platform model. + model (TargetPlatformCapabilities): The instantiated target platform model. Returns: - Dict[str, Any]: The validated values. + TargetPlatformCapabilities: The validated model. 
""" # Validate `default_qco` - default_qco = values.get('default_qco') + default_qco = model.default_qco if len(default_qco.quantization_configurations) != 1: Logger.critical("Default QuantizationConfigOptions must contain exactly one option.") # pragma: no cover # Validate `operator_set` uniqueness - operator_set = values.get('operator_set') + operator_set = model.operator_set if operator_set is not None: opsets_names = [ op.name.value if isinstance(op.name, OperatorSetNames) else op.name @@ -678,7 +658,7 @@ def validate_after_initialization(cls, values: Dict[str, Any]) -> Dict[str, Any] if len(set(opsets_names)) != len(opsets_names): Logger.critical("Operator Sets must have unique names.") # pragma: no cover - return values + return model def get_info(self) -> Dict[str, Any]: """ diff --git a/model_compression_toolkit/target_platform_capabilities/schema/v2.py b/model_compression_toolkit/target_platform_capabilities/schema/v2.py index 475d4379c..bad8cf6b4 100644 --- a/model_compression_toolkit/target_platform_capabilities/schema/v2.py +++ b/model_compression_toolkit/target_platform_capabilities/schema/v2.py @@ -13,10 +13,9 @@ # limitations under the License. 
# ============================================================================== import pprint -from enum import Enum from typing import Dict, Any, Tuple, Optional -from pydantic import BaseModel, root_validator +from pydantic import BaseModel, root_validator, model_validator, ConfigDict from mct_quantizers import QuantizationMethod from model_compression_toolkit.constants import FLOAT_BITWIDTH @@ -62,27 +61,26 @@ class TargetPlatformCapabilities(BaseModel): SCHEMA_VERSION: int = 2 - class Config: - frozen = True + model_config = ConfigDict(frozen=True) - @root_validator(allow_reuse=True) - def validate_after_initialization(cls, values: Dict[str, Any]) -> Dict[str, Any]: + @model_validator(mode="after") + def validate_after_initialization(cls, model: 'TargetPlatformCapabilities') -> Any: """ Perform validation after the model has been instantiated. Args: - values (Dict[str, Any]): The instantiated target platform model. + model (TargetPlatformCapabilities): The instantiated target platform model. Returns: - Dict[str, Any]: The validated values. + TargetPlatformCapabilities: The validated model. 
""" # Validate `default_qco` - default_qco = values.get('default_qco') + default_qco = model.default_qco if len(default_qco.quantization_configurations) != 1: Logger.critical("Default QuantizationConfigOptions must contain exactly one option.") # pragma: no cover # Validate `operator_set` uniqueness - operator_set = values.get('operator_set') + operator_set = model.operator_set if operator_set is not None: opsets_names = [ op.name.value if isinstance(op.name, OperatorSetNames) else op.name @@ -91,7 +89,7 @@ def validate_after_initialization(cls, values: Dict[str, Any]) -> Dict[str, Any] if len(set(opsets_names)) != len(opsets_names): Logger.critical("Operator Sets must have unique names.") # pragma: no cover - return values + return model def get_info(self) -> Dict[str, Any]: """ diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_io_handler.py b/model_compression_toolkit/target_platform_capabilities/tpc_io_handler.py index 653192c0e..d40905627 100644 --- a/model_compression_toolkit/target_platform_capabilities/tpc_io_handler.py +++ b/model_compression_toolkit/target_platform_capabilities/tpc_io_handler.py @@ -100,6 +100,6 @@ def export_target_platform_capabilities(model: schema.TargetPlatformCapabilities # Export the model to JSON and write to the file with path.open('w', encoding='utf-8') as file: - file.write(model.json(indent=4)) + file.write(model.model_dump_json(indent=4)) except OSError as e: raise OSError(f"Failed to write to file '{export_path}': {e.strerror}") from e \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 1ea53aadf..8f1ee1298 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,5 +10,6 @@ matplotlib<3.10.0 scipy protobuf mct-quantizers-nightly -pydantic<2.0 +pydantic>=2.0 sony-custom-layers-dev==0.4.0.dev6 + diff --git a/tests_pytest/common_tests/unit_tests/core/quantization/node_quantization_config/test_node_weights_quantization_config.py 
b/tests_pytest/common_tests/unit_tests/core/quantization/node_quantization_config/test_node_weights_quantization_config.py index 9c6fb660e..13a0cb0a5 100644 --- a/tests_pytest/common_tests/unit_tests/core/quantization/node_quantization_config/test_node_weights_quantization_config.py +++ b/tests_pytest/common_tests/unit_tests/core/quantization/node_quantization_config/test_node_weights_quantization_config.py @@ -115,7 +115,7 @@ def test_node_weights_quantization_config_op_cfg_mapping(self): # Test using the positional attribute as the key rather than POS_ATTR; this mismatch should cause # NodeWeightsQuantizationConfig to fall back to the default weights attribute configuration instead of # applying the specific one. - op_cfg = self._create_node_weights_op_cfg(pos_weight_attr=[positional_weight_attr], + op_cfg = self._create_node_weights_op_cfg(pos_weight_attr=[str(positional_weight_attr)], pos_weight_attr_config=[pos_weight_attr_config], def_weight_attr_config=def_weight_attr_config) diff --git a/tests_pytest/common_tests/unit_tests/target_platform_capabilities/test_tpc.py b/tests_pytest/common_tests/unit_tests/target_platform_capabilities/test_tpc.py index 1d69e29f1..03e026b32 100644 --- a/tests_pytest/common_tests/unit_tests/target_platform_capabilities/test_tpc.py +++ b/tests_pytest/common_tests/unit_tests/target_platform_capabilities/test_tpc.py @@ -14,6 +14,7 @@ # ============================================================================== import os import pytest +from pydantic import ValidationError import model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema as current_schema from model_compression_toolkit.core.common import BaseNode from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR @@ -165,7 +166,7 @@ def test_valid_export(self, tpc, valid_export_path): assert os.path.exists(str(valid_export_path)) with open(str(valid_export_path), "r", encoding="utf-8") as file: content = file.read() - 
assert content == tpc.json(indent=4) + assert content == tpc.model_dump_json(indent=4) def test_export_with_invalid_model(self, valid_export_path): """Tests that exporting an invalid model raises a ValueError.""" @@ -178,13 +179,13 @@ def test_export_with_invalid_path(self, tpc, invalid_export_path): export_target_platform_capabilities(tpc, str(invalid_export_path)) def test_export_creates_parent_directories(self, tpc, tmp_path): - """Tests that exporting to an invalid path raises an OSError.""" + """Tests that exporting creates parent directories as needed.""" nested_path = tmp_path / "nested" / "directory" / "exported_model.json" export_target_platform_capabilities(tpc, str(nested_path)) assert os.path.exists(str(nested_path)) with open(str(nested_path), "r", encoding="utf-8") as file: content = file.read() - assert content == tpc.json(indent=4) + assert content == tpc.model_dump_json(indent=4) # Cleanup created directories os.remove(str(nested_path)) os.rmdir(str(tmp_path / "nested" / "directory")) @@ -200,17 +201,19 @@ def test_export_then_import(self, tpc, valid_export_path): class TestTargetPlatformModeling: def test_immutable_tp(self): """Tests that modifying an immutable TargetPlatformCapabilities instance raises an exception.""" - with pytest.raises(Exception, match='"TargetPlatformCapabilities" is immutable and does not support item assignment'): - model = current_schema.TargetPlatformCapabilities( - default_qco=TEST_QCO, + model = current_schema.TargetPlatformCapabilities( + default_qco=TEST_QCO, operator_set=(current_schema.OperatorsSet(name="opset"),), - tpc_minor_version=None, - tpc_patch_version=None, - tpc_platform_type=None, - add_metadata=False - ) + tpc_minor_version=None, + tpc_patch_version=None, + tpc_platform_type=None, + add_metadata=False + ) + # Assigning to a field of a frozen model raises a pydantic ValidationError + with pytest.raises(ValidationError, match="Instance is frozen"): model.operator_set = tuple() + def 
test_default_options_more_than_single_qc(self): """Tests that creating a TargetPlatformCapabilities with default_qco containing more than one configuration raises an exception.""" test_qco = current_schema.QuantizationConfigOptions( @@ -304,7 +307,7 @@ def test_empty_qc_options(self): def test_list_of_no_qc(self): """Tests that providing an invalid configuration list (non-dict values) to QuantizationConfigOptions raises an exception.""" - with pytest.raises(Exception, match="value is not a valid dict"): + with pytest.raises(ValidationError, match="Input should be a valid dictionary"): current_schema.QuantizationConfigOptions(quantization_configurations=(TEST_QC, 3), base_config=TEST_QC) def test_clone_and_edit_options(self):