From d59d99c13032447b00e4386b4759475eea5a61f1 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Tue, 1 Apr 2025 16:23:48 +0200 Subject: [PATCH 1/7] Remove `pydantic<2.0` pin in requirements.txt Request to unpin this dependency as this is causing installation conflicts in the Ultralytics package when used with other export dependencies that require more recent versions of Pydantic. I think the main changes may be needed in `model_compression_toolkit` here: --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 1b14dbbc4..085474615 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,4 +10,4 @@ matplotlib<3.10.0 scipy protobuf mct-quantizers==1.5.2 -pydantic<2.0 \ No newline at end of file +pydantic From 612593e5434c04254439705b2226beb027a81b30 Mon Sep 17 00:00:00 2001 From: liord Date: Tue, 8 Apr 2025 14:24:33 +0300 Subject: [PATCH 2/7] Add support for pydantic v2 --- .../target_platform_capabilities/schema/v1.py | 114 ++++++++---------- .../tpc_io_handler.py | 4 +- .../target_platform_capabilities/test_tpc.py | 30 +++-- 3 files changed, 66 insertions(+), 82 deletions(-) diff --git a/model_compression_toolkit/target_platform_capabilities/schema/v1.py b/model_compression_toolkit/target_platform_capabilities/schema/v1.py index fbbf5a2fb..284b56cd2 100644 --- a/model_compression_toolkit/target_platform_capabilities/schema/v1.py +++ b/model_compression_toolkit/target_platform_capabilities/schema/v1.py @@ -16,7 +16,8 @@ from enum import Enum from typing import Dict, Any, Union, Tuple, List, Optional, Literal, Annotated -from pydantic import BaseModel, Field, root_validator, validator, PositiveInt +from pydantic import BaseModel, Field, root_validator, validator, PositiveInt, ConfigDict, field_validator, \ + model_validator from mct_quantizers import QuantizationMethod from model_compression_toolkit.constants import FLOAT_BITWIDTH @@ -118,9 +119,7 @@ class 
AttributeQuantizationConfig(BaseModel): enable_weights_quantization: bool = False lut_values_bitwidth: Optional[int] = None - class Config: - # Makes the model immutable (frozen) - frozen = True + model_config = ConfigDict(frozen=True) @property def field_names(self) -> list: @@ -137,7 +136,7 @@ def clone_and_edit(self, **kwargs) -> 'AttributeQuantizationConfig': Returns: AttributeQuantizationConfig: A new instance of AttributeQuantizationConfig with updated attributes. """ - return self.copy(update=kwargs) + return self.model_copy(update=kwargs) class OpQuantizationConfig(BaseModel): @@ -164,15 +163,14 @@ class OpQuantizationConfig(BaseModel): supported_input_activation_n_bits: Union[int, Tuple[int, ...]] enable_activation_quantization: bool quantization_preserving: bool - fixed_scale: Optional[float] - fixed_zero_point: Optional[int] - simd_size: Optional[int] signedness: Signedness + fixed_scale: Optional[float] = None + fixed_zero_point: Optional[int] = None + simd_size: Optional[int] = None - class Config: - frozen = True + model_config = ConfigDict(frozen=True) - @validator('supported_input_activation_n_bits', pre=True, allow_reuse=True) + @field_validator('supported_input_activation_n_bits', mode='before') def validate_supported_input_activation_n_bits(cls, v): """ Validate and process the supported_input_activation_n_bits field. @@ -199,9 +197,9 @@ def get_info(self) -> Dict[str, Any]: return self.dict() # pragma: no cover def clone_and_edit( - self, - attr_to_edit: Dict[str, Dict[str, Any]] = {}, - **kwargs: Any + self, + attr_to_edit: Dict[str, Dict[str, Any]] = {}, + **kwargs: Any ) -> 'OpQuantizationConfig': """ Clone the quantization config and edit some of its attributes. @@ -215,17 +213,17 @@ def clone_and_edit( OpQuantizationConfig: Edited quantization configuration. 
""" # Clone and update top-level attributes - updated_config = self.copy(update=kwargs) + updated_config = self.model_copy(update=kwargs) # Clone and update nested immutable dataclasses in `attr_weights_configs_mapping` updated_attr_mapping = { attr_name: (attr_cfg.clone_and_edit(**attr_to_edit[attr_name]) - if attr_name in attr_to_edit else attr_cfg) + if attr_name in attr_to_edit else attr_cfg) for attr_name, attr_cfg in updated_config.attr_weights_configs_mapping.items() } # Return a new instance with the updated attribute mapping - return updated_config.copy(update={'attr_weights_configs_mapping': updated_attr_mapping}) + return updated_config.model_copy(update={'attr_weights_configs_mapping': updated_attr_mapping}) class QuantizationConfigOptions(BaseModel): @@ -239,10 +237,9 @@ class QuantizationConfigOptions(BaseModel): quantization_configurations: Tuple[OpQuantizationConfig, ...] base_config: Optional[OpQuantizationConfig] = None - class Config: - frozen = True + model_config = ConfigDict(frozen=True) - @root_validator(pre=True, allow_reuse=True) + @model_validator(mode="before") def validate_and_set_base_config(cls, values: Dict[str, Any]) -> Dict[str, Any]: """ Validate and set the base_config based on quantization_configurations. @@ -282,12 +279,6 @@ def validate_and_set_base_config(cls, values: Dict[str, Any]) -> Dict[str, Any]: "'base_config' must be included in the quantization config options." ) # pragma: no cover - # if num_configs == 1: - # if base_config != quantization_configurations[0]: - # Logger.critical( - # "'base_config' should be the same as the sole item in 'quantization_configurations'." - # ) # pragma: no cover - values['base_config'] = base_config # When loading from JSON, lists are returned. If the value is a list, convert it to a tuple. 
@@ -312,7 +303,7 @@ def clone_and_edit(self, **kwargs) -> 'QuantizationConfigOptions': # Clone and update all configurations updated_configs = tuple(cfg.clone_and_edit(**kwargs) for cfg in self.quantization_configurations) - return self.copy(update={ + return self.model_copy(update={ 'base_config': updated_base_config, 'quantization_configurations': updated_configs }) @@ -360,7 +351,7 @@ def clone_and_edit_weight_attribute( updated_cfg = qc.clone_and_edit(attr_weights_configs_mapping=updated_attr_mapping) updated_configs.append(updated_cfg) - return self.copy(update={ + return self.model_copy(update={ 'base_config': updated_base_config, 'quantization_configurations': tuple(updated_configs) }) @@ -398,7 +389,7 @@ def clone_and_map_weights_attr_keys( updated_cfg = qc.clone_and_edit(attr_weights_configs_mapping=new_attr_mapping) updated_configs.append(updated_cfg) - return self.copy(update={ + return self.model_copy(update={ 'base_config': new_base_config, 'quantization_configurations': tuple(updated_configs) }) @@ -412,12 +403,12 @@ def get_info(self) -> Dict[str, Any]: """ return {f'option_{i}': cfg.get_info() for i, cfg in enumerate(self.quantization_configurations)} + class TargetPlatformModelComponent(BaseModel): """ Component of TargetPlatformCapabilities (Fusing, OperatorsSet, etc.). 
""" - class Config: - frozen = True + model_config = ConfigDict(frozen=True) class OperatorsSetBase(TargetPlatformModelComponent): @@ -444,8 +435,7 @@ class OperatorsSet(OperatorsSetBase): # Define a private attribute _type type: Literal["OperatorsSet"] = "OperatorsSet" - class Config: - frozen = True + model_config = ConfigDict(frozen=True) def get_info(self) -> Dict[str, Any]: """ @@ -471,10 +461,9 @@ class OperatorSetGroup(OperatorsSetBase): # Define a private attribute _type type: Literal["OperatorSetGroup"] = "OperatorSetGroup" - class Config: - frozen = True + model_config = ConfigDict(frozen=True) - @root_validator(pre=True, allow_reuse=True) + @model_validator(mode="before") def validate_and_set_name(cls, values: Dict[str, Any]) -> Dict[str, Any]: """ Validate the input and set the concatenated name based on the operators_set. @@ -512,6 +501,7 @@ def get_info(self) -> Dict[str, Any]: "operators_set": [op.get_info() for op in self.operators_set] } + class Fusing(TargetPlatformModelComponent): """ Fusing defines a tuple of operators that should be combined and treated as a single operator, @@ -525,10 +515,9 @@ class Fusing(TargetPlatformModelComponent): operator_groups: Tuple[Annotated[Union[OperatorsSet, OperatorSetGroup], Field(discriminator='type')], ...] name: Optional[str] = None # Will be set in the validator if not given. - class Config: - frozen = True + model_config = ConfigDict(frozen=True) - @root_validator(pre=True, allow_reuse=True) + @model_validator(mode="before") def validate_and_set_name(cls, values: Dict[str, Any]) -> Dict[str, Any]: """ Validate the operator_groups and set the name by concatenating operator group names. 
@@ -555,24 +544,15 @@ def validate_and_set_name(cls, values: Dict[str, Any]) -> Dict[str, Any]: return values - @root_validator(allow_reuse=True) - def validate_after_initialization(cls, values: Dict[str, Any]) -> Dict[str, Any]: + @model_validator(mode="after") + def validate_after_initialization(cls, model: 'Fusing') -> Any: """ Perform validation after the model has been instantiated. - - Args: - values (Dict[str, Any]): The instantiated fusing. - - Returns: - Dict[str, Any]: The validated values. + Ensures that there are at least two operator groups. """ - operator_groups = values.get('operator_groups') - - # Validate that there are at least two operator groups - if len(operator_groups) < 2: + if len(model.operator_groups) < 2: Logger.critical("Fusing cannot be created for a single operator.") # pragma: no cover - - return values + return model def contains(self, other: Any) -> bool: """ @@ -621,6 +601,7 @@ def get_info(self) -> Union[Dict[str, str], str]: for x in self.operator_groups ]) + class TargetPlatformCapabilities(BaseModel): """ Represents the hardware configuration used for quantized model inference. @@ -638,38 +619,37 @@ class TargetPlatformCapabilities(BaseModel): SCHEMA_VERSION (int): Version of the schema for the Target Platform Model. 
""" default_qco: QuantizationConfigOptions - operator_set: Optional[Tuple[OperatorsSet, ...]] - fusing_patterns: Optional[Tuple[Fusing, ...]] - tpc_minor_version: Optional[int] - tpc_patch_version: Optional[int] - tpc_platform_type: Optional[str] + operator_set: Optional[Tuple[OperatorsSet, ...]] = None + fusing_patterns: Optional[Tuple[Fusing, ...]] = None + tpc_minor_version: Optional[int] = None + tpc_patch_version: Optional[int] = None + tpc_platform_type: Optional[str] = None add_metadata: bool = True name: Optional[str] = "default_tpc" is_simd_padding: bool = False SCHEMA_VERSION: int = 1 - class Config: - frozen = True + model_config = ConfigDict(frozen=True) - @root_validator(allow_reuse=True) - def validate_after_initialization(cls, values: Dict[str, Any]) -> Dict[str, Any]: + @model_validator(mode="after") + def validate_after_initialization(cls, model: 'TargetPlatformCapabilities') -> Any: """ Perform validation after the model has been instantiated. Args: - values (Dict[str, Any]): The instantiated target platform model. + model (TargetPlatformCapabilities): The instantiated target platform model. Returns: - Dict[str, Any]: The validated values. + TargetPlatformCapabilities: The validated model. 
""" # Validate `default_qco` - default_qco = values.get('default_qco') + default_qco = model.default_qco if len(default_qco.quantization_configurations) != 1: Logger.critical("Default QuantizationConfigOptions must contain exactly one option.") # pragma: no cover # Validate `operator_set` uniqueness - operator_set = values.get('operator_set') + operator_set = model.operator_set if operator_set is not None: opsets_names = [ op.name.value if isinstance(op.name, OperatorSetNames) else op.name @@ -678,7 +658,7 @@ def validate_after_initialization(cls, values: Dict[str, Any]) -> Dict[str, Any] if len(set(opsets_names)) != len(opsets_names): Logger.critical("Operator Sets must have unique names.") # pragma: no cover - return values + return model def get_info(self) -> Dict[str, Any]: """ diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_io_handler.py b/model_compression_toolkit/target_platform_capabilities/tpc_io_handler.py index 49ad7a28c..5c9ac7f34 100644 --- a/model_compression_toolkit/target_platform_capabilities/tpc_io_handler.py +++ b/model_compression_toolkit/target_platform_capabilities/tpc_io_handler.py @@ -54,7 +54,7 @@ def load_target_platform_capabilities(tpc_obj_or_path: Union[TargetPlatformCapab raise ValueError(f"Error reading the file '{tpc_obj_or_path}': {e.strerror}.") from e try: - return TargetPlatformCapabilities.parse_raw(data) + return TargetPlatformCapabilities.model_validate_json(data) except ValueError as e: raise ValueError(f"Invalid JSON for loading TargetPlatformCapabilities in '{tpc_obj_or_path}': {e}.") from e except Exception as e: @@ -88,6 +88,6 @@ def export_target_platform_capabilities(model: TargetPlatformCapabilities, expor # Export the model to JSON and write to the file with path.open('w', encoding='utf-8') as file: - file.write(model.json(indent=4)) + file.write(model.model_dump_json(indent=4)) except OSError as e: raise OSError(f"Failed to write to file '{export_path}': {e.strerror}") from e \ No newline at 
end of file diff --git a/tests_pytest/common_tests/unit_tests/target_platform_capabilities/test_tpc.py b/tests_pytest/common_tests/unit_tests/target_platform_capabilities/test_tpc.py index 03ce2d6f1..9e284afd5 100644 --- a/tests_pytest/common_tests/unit_tests/target_platform_capabilities/test_tpc.py +++ b/tests_pytest/common_tests/unit_tests/target_platform_capabilities/test_tpc.py @@ -14,6 +14,8 @@ # ============================================================================== import os import pytest +from pydantic import ValidationError + import model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema as schema from model_compression_toolkit.core.common import BaseNode from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR @@ -117,7 +119,7 @@ def test_valid_export(self, tpc, valid_export_path): assert os.path.exists(str(valid_export_path)) with open(str(valid_export_path), "r", encoding="utf-8") as file: content = file.read() - assert content == tpc.json(indent=4) + assert content == tpc.model_dump_json(indent=4) def test_export_with_invalid_model(self, valid_export_path): """Tests that exporting an invalid model raises a ValueError.""" @@ -130,13 +132,13 @@ def test_export_with_invalid_path(self, tpc, invalid_export_path): export_target_platform_capabilities(tpc, str(invalid_export_path)) def test_export_creates_parent_directories(self, tpc, tmp_path): - """Tests that exporting to an invalid path raises an OSError.""" + """Tests that exporting creates parent directories as needed.""" nested_path = tmp_path / "nested" / "directory" / "exported_model.json" export_target_platform_capabilities(tpc, str(nested_path)) assert os.path.exists(str(nested_path)) with open(str(nested_path), "r", encoding="utf-8") as file: content = file.read() - assert content == tpc.json(indent=4) + assert content == tpc.model_dump_json(indent=4) # Cleanup created directories os.remove(str(nested_path)) os.rmdir(str(tmp_path / 
"nested" / "directory")) @@ -152,17 +154,19 @@ def test_export_then_import(self, tpc, valid_export_path): class TestTargetPlatformModeling: def test_immutable_tp(self): """Tests that modifying an immutable TargetPlatformCapabilities instance raises an exception.""" - with pytest.raises(Exception, match='"TargetPlatformCapabilities" is immutable and does not support item assignment'): - model = schema.TargetPlatformCapabilities( - default_qco=TEST_QCO, - operator_set=(schema.OperatorsSet(name="opset"),), - tpc_minor_version=None, - tpc_patch_version=None, - tpc_platform_type=None, - add_metadata=False - ) + model = schema.TargetPlatformCapabilities( + default_qco=TEST_QCO, + operator_set=(schema.OperatorsSet(name="opset"),), + tpc_minor_version=None, + tpc_patch_version=None, + tpc_platform_type=None, + add_metadata=False + ) + # Expecting a TypeError or AttributeError due to immutability + with pytest.raises(ValidationError , match="Instance is frozen"): model.operator_set = tuple() + def test_default_options_more_than_single_qc(self): """Tests that creating a TargetPlatformCapabilities with default_qco containing more than one configuration raises an exception.""" test_qco = schema.QuantizationConfigOptions( @@ -256,7 +260,7 @@ def test_empty_qc_options(self): def test_list_of_no_qc(self): """Tests that providing an invalid configuration list (non-dict values) to QuantizationConfigOptions raises an exception.""" - with pytest.raises(Exception, match="value is not a valid dict"): + with pytest.raises(ValidationError, match="Input should be a valid dictionary"): schema.QuantizationConfigOptions(quantization_configurations=(TEST_QC, 3), base_config=TEST_QC) def test_clone_and_edit_options(self): From a824569937f4b7483413d25add629cf7558905ed Mon Sep 17 00:00:00 2001 From: liord Date: Wed, 9 Apr 2025 10:36:15 +0300 Subject: [PATCH 3/7] Fix test that fails due to upgrade to pydantic v2 --- .../test_node_weights_quantization_config.py | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/tests_pytest/common_tests/unit_tests/core/quantization/node_quantization_config/test_node_weights_quantization_config.py b/tests_pytest/common_tests/unit_tests/core/quantization/node_quantization_config/test_node_weights_quantization_config.py index 9c6fb660e..13a0cb0a5 100644 --- a/tests_pytest/common_tests/unit_tests/core/quantization/node_quantization_config/test_node_weights_quantization_config.py +++ b/tests_pytest/common_tests/unit_tests/core/quantization/node_quantization_config/test_node_weights_quantization_config.py @@ -115,7 +115,7 @@ def test_node_weights_quantization_config_op_cfg_mapping(self): # Test using the positional attribute as the key rather than POS_ATTR; this mismatch should cause # NodeWeightsQuantizationConfig to fall back to the default weights attribute configuration instead of # applying the specific one. - op_cfg = self._create_node_weights_op_cfg(pos_weight_attr=[positional_weight_attr], + op_cfg = self._create_node_weights_op_cfg(pos_weight_attr=[str(positional_weight_attr)], pos_weight_attr_config=[pos_weight_attr_config], def_weight_attr_config=def_weight_attr_config) From 00bba0023078b903a972ac8a56d843884945557e Mon Sep 17 00:00:00 2001 From: liord Date: Wed, 9 Apr 2025 14:54:12 +0300 Subject: [PATCH 4/7] Remove tests for keras/tensorflow < 2.14.
Update requirements.txt --- .../run_tests_python310_keras212.yml | 19 ------------------- .../run_tests_python310_keras213.yml | 19 ------------------- .../run_tests_python311_keras212.yml | 19 ------------------- .../run_tests_python311_keras213.yml | 19 ------------------- .../workflows/run_tests_python39_keras212.yml | 19 ------------------- .../workflows/run_tests_python39_keras213.yml | 19 ------------------- .../workflows/run_tests_suite_coverage.yml | 2 +- README.md | 2 +- requirements.txt | 2 +- 9 files changed, 3 insertions(+), 117 deletions(-) delete mode 100644 .github/workflows/run_tests_python310_keras212.yml delete mode 100644 .github/workflows/run_tests_python310_keras213.yml delete mode 100644 .github/workflows/run_tests_python311_keras212.yml delete mode 100644 .github/workflows/run_tests_python311_keras213.yml delete mode 100644 .github/workflows/run_tests_python39_keras212.yml delete mode 100644 .github/workflows/run_tests_python39_keras213.yml diff --git a/.github/workflows/run_tests_python310_keras212.yml b/.github/workflows/run_tests_python310_keras212.yml deleted file mode 100644 index 172faaa31..000000000 --- a/.github/workflows/run_tests_python310_keras212.yml +++ /dev/null @@ -1,19 +0,0 @@ -name: Python 3.10, Keras 2.12 -on: - workflow_dispatch: # Allow manual triggers - schedule: - - cron: 0 0 * * * - pull_request: - branches: - - main - -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} - -jobs: - run-tests: - uses: ./.github/workflows/run_keras_tests.yml - with: - python-version: "3.10" - tf-version: "2.12.*" \ No newline at end of file diff --git a/.github/workflows/run_tests_python310_keras213.yml b/.github/workflows/run_tests_python310_keras213.yml deleted file mode 100644 index 7a90da443..000000000 --- a/.github/workflows/run_tests_python310_keras213.yml +++ /dev/null @@ -1,19 +0,0 @@ -name: Python 3.10, Keras 2.13 -on: - workflow_dispatch: # Allow manual 
triggers - schedule: - - cron: 0 0 * * * - pull_request: - branches: - - main - -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} - -jobs: - run-tests: - uses: ./.github/workflows/run_keras_tests.yml - with: - python-version: "3.10" - tf-version: "2.13.*" \ No newline at end of file diff --git a/.github/workflows/run_tests_python311_keras212.yml b/.github/workflows/run_tests_python311_keras212.yml deleted file mode 100644 index b129a5843..000000000 --- a/.github/workflows/run_tests_python311_keras212.yml +++ /dev/null @@ -1,19 +0,0 @@ -name: Python 3.11, Keras 2.12 -on: - workflow_dispatch: # Allow manual triggers - schedule: - - cron: 0 0 * * * - pull_request: - branches: - - main - -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} - -jobs: - run-tests: - uses: ./.github/workflows/run_keras_tests.yml - with: - python-version: "3.11" - tf-version: "2.12.*" \ No newline at end of file diff --git a/.github/workflows/run_tests_python311_keras213.yml b/.github/workflows/run_tests_python311_keras213.yml deleted file mode 100644 index 635534bfc..000000000 --- a/.github/workflows/run_tests_python311_keras213.yml +++ /dev/null @@ -1,19 +0,0 @@ -name: Python 3.11, Keras 2.13 -on: - workflow_dispatch: # Allow manual triggers - schedule: - - cron: 0 0 * * * - pull_request: - branches: - - main - -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} - -jobs: - run-tests: - uses: ./.github/workflows/run_keras_tests.yml - with: - python-version: "3.11" - tf-version: "2.13.*" \ No newline at end of file diff --git a/.github/workflows/run_tests_python39_keras212.yml b/.github/workflows/run_tests_python39_keras212.yml deleted file mode 100644 index 5f42dd35b..000000000 --- a/.github/workflows/run_tests_python39_keras212.yml +++ /dev/null @@ -1,19 +0,0 @@ -name: 
Python 3.9, Keras 2.12 -on: - workflow_dispatch: # Allow manual triggers - schedule: - - cron: 0 0 * * * - pull_request: - branches: - - main - -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} - -jobs: - run-tests: - uses: ./.github/workflows/run_keras_tests.yml - with: - python-version: "3.9" - tf-version: "2.12.*" \ No newline at end of file diff --git a/.github/workflows/run_tests_python39_keras213.yml b/.github/workflows/run_tests_python39_keras213.yml deleted file mode 100644 index fa7c34390..000000000 --- a/.github/workflows/run_tests_python39_keras213.yml +++ /dev/null @@ -1,19 +0,0 @@ -name: Python 3.9, Keras 2.13 -on: - workflow_dispatch: # Allow manual triggers - schedule: - - cron: 0 0 * * * - pull_request: - branches: - - main - -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} - -jobs: - run-tests: - uses: ./.github/workflows/run_keras_tests.yml - with: - python-version: "3.9" - tf-version: "2.13.*" \ No newline at end of file diff --git a/.github/workflows/run_tests_suite_coverage.yml b/.github/workflows/run_tests_suite_coverage.yml index 2b23fbef3..6b26e3c6e 100644 --- a/.github/workflows/run_tests_suite_coverage.yml +++ b/.github/workflows/run_tests_suite_coverage.yml @@ -47,7 +47,7 @@ jobs: python -m venv tf_env source tf_env/bin/activate python -m pip install --upgrade pip - pip install -r requirements.txt tensorflow==2.13.* sony-custom-layers coverage pytest pytest-mock + pip install -r requirements.txt tensorflow==2.15.* sony-custom-layers coverage pytest pytest-mock - name: Run TensorFlow tests (unittest) run: | diff --git a/README.md b/README.md index 40cbd3ae5..3f1062792 100644 --- a/README.md +++ b/README.md @@ -30,7 +30,7 @@ ________________________________________________________________________________ ##
Getting Started
### Quick Installation -Pip install the model compression toolkit package in a Python>=3.9 environment with PyTorch>=2.1 or Tensorflow>=2.12. +Pip install the model compression toolkit package in a Python>=3.9 environment with PyTorch>=2.1 or Tensorflow>=2.14. ``` pip install model-compression-toolkit ``` diff --git a/requirements.txt b/requirements.txt index bbac7cfd4..302a6005d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,4 +10,4 @@ matplotlib<3.10.0 scipy protobuf mct-quantizers-nightly -pydantic \ No newline at end of file +pydantic>=2.0 \ No newline at end of file From 4ef531a7721a97761e89cebcc13f62bd8b3c25b3 Mon Sep 17 00:00:00 2001 From: liord Date: Thu, 10 Apr 2025 14:23:40 +0300 Subject: [PATCH 5/7] update schema v2 with pydantic v2 --- .../target_platform_capabilities/schema/v2.py | 28 ++++++------------- 1 file changed, 9 insertions(+), 19 deletions(-) diff --git a/model_compression_toolkit/target_platform_capabilities/schema/v2.py b/model_compression_toolkit/target_platform_capabilities/schema/v2.py index 475d4379c..796f67427 100644 --- a/model_compression_toolkit/target_platform_capabilities/schema/v2.py +++ b/model_compression_toolkit/target_platform_capabilities/schema/v2.py @@ -13,23 +13,14 @@ # limitations under the License. 
# ============================================================================== import pprint -from enum import Enum from typing import Dict, Any, Tuple, Optional -from pydantic import BaseModel, root_validator +from pydantic import BaseModel, model_validator, ConfigDict -from mct_quantizers import QuantizationMethod -from model_compression_toolkit.constants import FLOAT_BITWIDTH from model_compression_toolkit.logger import Logger from model_compression_toolkit.target_platform_capabilities.schema.v1 import ( - Signedness, - AttributeQuantizationConfig, - OpQuantizationConfig, QuantizationConfigOptions, - TargetPlatformModelComponent, - OperatorsSetBase, OperatorsSet, - OperatorSetGroup, Fusing, OperatorSetNames) @@ -62,27 +53,26 @@ class TargetPlatformCapabilities(BaseModel): SCHEMA_VERSION: int = 2 - class Config: - frozen = True + model_config = ConfigDict(frozen=True) - @root_validator(allow_reuse=True) - def validate_after_initialization(cls, values: Dict[str, Any]) -> Dict[str, Any]: + @model_validator(mode="after") + def validate_after_initialization(cls, model: 'TargetPlatformCapabilities') -> Any: """ Perform validation after the model has been instantiated. Args: - values (Dict[str, Any]): The instantiated target platform model. + model (TargetPlatformCapabilities): The instantiated target platform model. Returns: - Dict[str, Any]: The validated values. + TargetPlatformCapabilities: The validated model. 
""" # Validate `default_qco` - default_qco = values.get('default_qco') + default_qco = model.default_qco if len(default_qco.quantization_configurations) != 1: Logger.critical("Default QuantizationConfigOptions must contain exactly one option.") # pragma: no cover # Validate `operator_set` uniqueness - operator_set = values.get('operator_set') + operator_set = model.operator_set if operator_set is not None: opsets_names = [ op.name.value if isinstance(op.name, OperatorSetNames) else op.name @@ -91,7 +81,7 @@ def validate_after_initialization(cls, values: Dict[str, Any]) -> Dict[str, Any] if len(set(opsets_names)) != len(opsets_names): Logger.critical("Operator Sets must have unique names.") # pragma: no cover - return values + return model def get_info(self) -> Dict[str, Any]: """ From 1703e3af655a0c2f4628e047d19a7993357eb306 Mon Sep 17 00:00:00 2001 From: liord Date: Thu, 10 Apr 2025 14:28:57 +0300 Subject: [PATCH 6/7] update schema v2 with pydantic v2 --- .../target_platform_capabilities/schema/v2.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/model_compression_toolkit/target_platform_capabilities/schema/v2.py b/model_compression_toolkit/target_platform_capabilities/schema/v2.py index 796f67427..bad8cf6b4 100644 --- a/model_compression_toolkit/target_platform_capabilities/schema/v2.py +++ b/model_compression_toolkit/target_platform_capabilities/schema/v2.py @@ -15,12 +15,20 @@ import pprint from typing import Dict, Any, Tuple, Optional -from pydantic import BaseModel, model_validator, ConfigDict +from pydantic import BaseModel, root_validator, model_validator, ConfigDict +from mct_quantizers import QuantizationMethod +from model_compression_toolkit.constants import FLOAT_BITWIDTH from model_compression_toolkit.logger import Logger from model_compression_toolkit.target_platform_capabilities.schema.v1 import ( + Signedness, + AttributeQuantizationConfig, + OpQuantizationConfig, QuantizationConfigOptions, + 
TargetPlatformModelComponent, + OperatorsSetBase, OperatorsSet, + OperatorSetGroup, Fusing, OperatorSetNames) From e717616f175226eaff5b9cc7d41d209f8387f26f Mon Sep 17 00:00:00 2001 From: liord Date: Mon, 14 Apr 2025 14:23:41 +0300 Subject: [PATCH 7/7] Suppress specific warning regarding pydantic v2 in docsrc creation. --- docsrc/source/conf.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/docsrc/source/conf.py b/docsrc/source/conf.py index c375580de..e2eff2bf4 100644 --- a/docsrc/source/conf.py +++ b/docsrc/source/conf.py @@ -10,6 +10,20 @@ # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. # +import logging +import re + +# Define a filter that suppresses log records matching our unwanted warning message. +class IgnoreGuardedTypeImportFilter(logging.Filter): + def filter(self, record): + # Return False (i.e. ignore) if the message contains our specific text. + if re.search(r"Failed guarded type import with ImportError.*AbstractSetIntStr", record.getMessage()): + return False + return True + +# Attach the filter to the "sphinx" logger. +logger = logging.getLogger("sphinx") +logger.addFilter(IgnoreGuardedTypeImportFilter()) import os import sys