diff --git a/.github/workflows/run_tests_python310_keras212.yml b/.github/workflows/run_tests_python310_keras212.yml
deleted file mode 100644
index 172faaa31..000000000
--- a/.github/workflows/run_tests_python310_keras212.yml
+++ /dev/null
@@ -1,19 +0,0 @@
-name: Python 3.10, Keras 2.12
-on:
- workflow_dispatch: # Allow manual triggers
- schedule:
- - cron: 0 0 * * *
- pull_request:
- branches:
- - main
-
-concurrency:
- group: ${{ github.workflow }}-${{ github.ref }}
- cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
-
-jobs:
- run-tests:
- uses: ./.github/workflows/run_keras_tests.yml
- with:
- python-version: "3.10"
- tf-version: "2.12.*"
\ No newline at end of file
diff --git a/.github/workflows/run_tests_python310_keras213.yml b/.github/workflows/run_tests_python310_keras213.yml
deleted file mode 100644
index 7a90da443..000000000
--- a/.github/workflows/run_tests_python310_keras213.yml
+++ /dev/null
@@ -1,19 +0,0 @@
-name: Python 3.10, Keras 2.13
-on:
- workflow_dispatch: # Allow manual triggers
- schedule:
- - cron: 0 0 * * *
- pull_request:
- branches:
- - main
-
-concurrency:
- group: ${{ github.workflow }}-${{ github.ref }}
- cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
-
-jobs:
- run-tests:
- uses: ./.github/workflows/run_keras_tests.yml
- with:
- python-version: "3.10"
- tf-version: "2.13.*"
\ No newline at end of file
diff --git a/.github/workflows/run_tests_python311_keras212.yml b/.github/workflows/run_tests_python311_keras212.yml
deleted file mode 100644
index b129a5843..000000000
--- a/.github/workflows/run_tests_python311_keras212.yml
+++ /dev/null
@@ -1,19 +0,0 @@
-name: Python 3.11, Keras 2.12
-on:
- workflow_dispatch: # Allow manual triggers
- schedule:
- - cron: 0 0 * * *
- pull_request:
- branches:
- - main
-
-concurrency:
- group: ${{ github.workflow }}-${{ github.ref }}
- cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
-
-jobs:
- run-tests:
- uses: ./.github/workflows/run_keras_tests.yml
- with:
- python-version: "3.11"
- tf-version: "2.12.*"
\ No newline at end of file
diff --git a/.github/workflows/run_tests_python311_keras213.yml b/.github/workflows/run_tests_python311_keras213.yml
deleted file mode 100644
index 635534bfc..000000000
--- a/.github/workflows/run_tests_python311_keras213.yml
+++ /dev/null
@@ -1,19 +0,0 @@
-name: Python 3.11, Keras 2.13
-on:
- workflow_dispatch: # Allow manual triggers
- schedule:
- - cron: 0 0 * * *
- pull_request:
- branches:
- - main
-
-concurrency:
- group: ${{ github.workflow }}-${{ github.ref }}
- cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
-
-jobs:
- run-tests:
- uses: ./.github/workflows/run_keras_tests.yml
- with:
- python-version: "3.11"
- tf-version: "2.13.*"
\ No newline at end of file
diff --git a/.github/workflows/run_tests_python39_keras212.yml b/.github/workflows/run_tests_python39_keras212.yml
deleted file mode 100644
index 5f42dd35b..000000000
--- a/.github/workflows/run_tests_python39_keras212.yml
+++ /dev/null
@@ -1,19 +0,0 @@
-name: Python 3.9, Keras 2.12
-on:
- workflow_dispatch: # Allow manual triggers
- schedule:
- - cron: 0 0 * * *
- pull_request:
- branches:
- - main
-
-concurrency:
- group: ${{ github.workflow }}-${{ github.ref }}
- cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
-
-jobs:
- run-tests:
- uses: ./.github/workflows/run_keras_tests.yml
- with:
- python-version: "3.9"
- tf-version: "2.12.*"
\ No newline at end of file
diff --git a/.github/workflows/run_tests_python39_keras213.yml b/.github/workflows/run_tests_python39_keras213.yml
deleted file mode 100644
index fa7c34390..000000000
--- a/.github/workflows/run_tests_python39_keras213.yml
+++ /dev/null
@@ -1,19 +0,0 @@
-name: Python 3.9, Keras 2.13
-on:
- workflow_dispatch: # Allow manual triggers
- schedule:
- - cron: 0 0 * * *
- pull_request:
- branches:
- - main
-
-concurrency:
- group: ${{ github.workflow }}-${{ github.ref }}
- cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
-
-jobs:
- run-tests:
- uses: ./.github/workflows/run_keras_tests.yml
- with:
- python-version: "3.9"
- tf-version: "2.13.*"
\ No newline at end of file
diff --git a/.github/workflows/run_tests_suite_coverage.yml b/.github/workflows/run_tests_suite_coverage.yml
index d8b4947bf..d916c0da9 100644
--- a/.github/workflows/run_tests_suite_coverage.yml
+++ b/.github/workflows/run_tests_suite_coverage.yml
@@ -65,7 +65,7 @@ jobs:
python -m venv tf_env
source tf_env/bin/activate
python -m pip install --upgrade pip
- pip install -r requirements.txt tensorflow==2.13.* coverage pytest pytest-mock
+ pip install -r requirements.txt tensorflow==2.15.* coverage pytest pytest-mock
- name: Run TensorFlow tests (unittest)
run: |
diff --git a/README.md b/README.md
index 40cbd3ae5..5caad0152 100644
--- a/README.md
+++ b/README.md
@@ -30,7 +30,7 @@ ________________________________________________________________________________
##
Getting Started
### Quick Installation
-Pip install the model compression toolkit package in a Python>=3.9 environment with PyTorch>=2.1 or Tensorflow>=2.12.
+Pip install the model compression toolkit package in a Python>=3.9 environment with PyTorch>=2.1 or Tensorflow>=2.14.
```
pip install model-compression-toolkit
```
@@ -137,11 +137,11 @@ Currently, MCT is being tested on various Python, Pytorch and TensorFlow version
| Python 3.11 | [](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch22.yml) | [](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch23.yml) | [](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch24.yml) | [](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch25.yml) |
| Python 3.12 | [](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch22.yml) | [](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch23.yml) | [](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch24.yml) | [](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch25.yml) |
-| | TensorFlow 2.12 | TensorFlow 2.13 | TensorFlow 2.14 | TensorFlow 2.15 |
+| | TensorFlow 2.14 | TensorFlow 2.15 |
-|-------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+|-------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
-| Python 3.9 | [](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras212.yml) | [](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras213.yml) | [](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras214.yml) | [](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras215.yml) |
-| Python 3.10 | [](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras212.yml) | [](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras213.yml) | [](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras214.yml) | [](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras215.yml) |
-| Python 3.11 | [](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras212.yml) | [](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras213.yml) | [](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras214.yml) | [](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras215.yml) |
+| Python 3.9 | [](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras214.yml) | [](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras215.yml) |
+| Python 3.10 | [](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras214.yml) | [](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras215.yml) |
+| Python 3.11 | [](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras214.yml) | [](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras215.yml) |
diff --git a/docsrc/source/conf.py b/docsrc/source/conf.py
index c375580de..e2eff2bf4 100644
--- a/docsrc/source/conf.py
+++ b/docsrc/source/conf.py
@@ -10,6 +10,20 @@
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
+import logging
+import re
+
+# Define a filter that suppresses log records matching our unwanted warning message.
+class IgnoreGuardedTypeImportFilter(logging.Filter):
+ def filter(self, record):
+ # Return False (i.e. ignore) if the message contains our specific text.
+ if re.search(r"Failed guarded type import with ImportError.*AbstractSetIntStr", record.getMessage()):
+ return False
+ return True
+
+# Attach the filter to the "sphinx" logger. NOTE(review): Logger filters are not
+# inherited by child loggers (e.g. "sphinx.ext.*"), so records emitted through a
+# child logger may bypass this filter — attach to the handler instead if the
+# warning still appears.
+logger = logging.getLogger("sphinx")
+logger.addFilter(IgnoreGuardedTypeImportFilter())
import os
import sys
diff --git a/model_compression_toolkit/target_platform_capabilities/schema/v1.py b/model_compression_toolkit/target_platform_capabilities/schema/v1.py
index fbbf5a2fb..284b56cd2 100644
--- a/model_compression_toolkit/target_platform_capabilities/schema/v1.py
+++ b/model_compression_toolkit/target_platform_capabilities/schema/v1.py
@@ -16,7 +16,8 @@
from enum import Enum
from typing import Dict, Any, Union, Tuple, List, Optional, Literal, Annotated
-from pydantic import BaseModel, Field, root_validator, validator, PositiveInt
+from pydantic import BaseModel, Field, root_validator, validator, PositiveInt, ConfigDict, field_validator, \
+ model_validator
from mct_quantizers import QuantizationMethod
from model_compression_toolkit.constants import FLOAT_BITWIDTH
@@ -118,9 +119,7 @@ class AttributeQuantizationConfig(BaseModel):
enable_weights_quantization: bool = False
lut_values_bitwidth: Optional[int] = None
- class Config:
- # Makes the model immutable (frozen)
- frozen = True
+ model_config = ConfigDict(frozen=True)
@property
def field_names(self) -> list:
@@ -137,7 +136,7 @@ def clone_and_edit(self, **kwargs) -> 'AttributeQuantizationConfig':
Returns:
AttributeQuantizationConfig: A new instance of AttributeQuantizationConfig with updated attributes.
"""
- return self.copy(update=kwargs)
+ return self.model_copy(update=kwargs)
class OpQuantizationConfig(BaseModel):
@@ -164,15 +163,14 @@ class OpQuantizationConfig(BaseModel):
supported_input_activation_n_bits: Union[int, Tuple[int, ...]]
enable_activation_quantization: bool
quantization_preserving: bool
- fixed_scale: Optional[float]
- fixed_zero_point: Optional[int]
- simd_size: Optional[int]
signedness: Signedness
+ fixed_scale: Optional[float] = None
+ fixed_zero_point: Optional[int] = None
+ simd_size: Optional[int] = None
- class Config:
- frozen = True
+ model_config = ConfigDict(frozen=True)
- @validator('supported_input_activation_n_bits', pre=True, allow_reuse=True)
+ @field_validator('supported_input_activation_n_bits', mode='before')
def validate_supported_input_activation_n_bits(cls, v):
"""
Validate and process the supported_input_activation_n_bits field.
@@ -199,9 +197,9 @@ def get_info(self) -> Dict[str, Any]:
return self.dict() # pragma: no cover
def clone_and_edit(
- self,
- attr_to_edit: Dict[str, Dict[str, Any]] = {},
- **kwargs: Any
+ self,
+ attr_to_edit: Dict[str, Dict[str, Any]] = {},
+ **kwargs: Any
) -> 'OpQuantizationConfig':
"""
Clone the quantization config and edit some of its attributes.
@@ -215,17 +213,17 @@ def clone_and_edit(
OpQuantizationConfig: Edited quantization configuration.
"""
# Clone and update top-level attributes
- updated_config = self.copy(update=kwargs)
+ updated_config = self.model_copy(update=kwargs)
# Clone and update nested immutable dataclasses in `attr_weights_configs_mapping`
updated_attr_mapping = {
attr_name: (attr_cfg.clone_and_edit(**attr_to_edit[attr_name])
- if attr_name in attr_to_edit else attr_cfg)
+ if attr_name in attr_to_edit else attr_cfg)
for attr_name, attr_cfg in updated_config.attr_weights_configs_mapping.items()
}
# Return a new instance with the updated attribute mapping
- return updated_config.copy(update={'attr_weights_configs_mapping': updated_attr_mapping})
+ return updated_config.model_copy(update={'attr_weights_configs_mapping': updated_attr_mapping})
class QuantizationConfigOptions(BaseModel):
@@ -239,10 +237,9 @@ class QuantizationConfigOptions(BaseModel):
quantization_configurations: Tuple[OpQuantizationConfig, ...]
base_config: Optional[OpQuantizationConfig] = None
- class Config:
- frozen = True
+ model_config = ConfigDict(frozen=True)
- @root_validator(pre=True, allow_reuse=True)
+ @model_validator(mode="before")
def validate_and_set_base_config(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""
Validate and set the base_config based on quantization_configurations.
@@ -282,12 +279,6 @@ def validate_and_set_base_config(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"'base_config' must be included in the quantization config options."
) # pragma: no cover
- # if num_configs == 1:
- # if base_config != quantization_configurations[0]:
- # Logger.critical(
- # "'base_config' should be the same as the sole item in 'quantization_configurations'."
- # ) # pragma: no cover
-
values['base_config'] = base_config
# When loading from JSON, lists are returned. If the value is a list, convert it to a tuple.
@@ -312,7 +303,7 @@ def clone_and_edit(self, **kwargs) -> 'QuantizationConfigOptions':
# Clone and update all configurations
updated_configs = tuple(cfg.clone_and_edit(**kwargs) for cfg in self.quantization_configurations)
- return self.copy(update={
+ return self.model_copy(update={
'base_config': updated_base_config,
'quantization_configurations': updated_configs
})
@@ -360,7 +351,7 @@ def clone_and_edit_weight_attribute(
updated_cfg = qc.clone_and_edit(attr_weights_configs_mapping=updated_attr_mapping)
updated_configs.append(updated_cfg)
- return self.copy(update={
+ return self.model_copy(update={
'base_config': updated_base_config,
'quantization_configurations': tuple(updated_configs)
})
@@ -398,7 +389,7 @@ def clone_and_map_weights_attr_keys(
updated_cfg = qc.clone_and_edit(attr_weights_configs_mapping=new_attr_mapping)
updated_configs.append(updated_cfg)
- return self.copy(update={
+ return self.model_copy(update={
'base_config': new_base_config,
'quantization_configurations': tuple(updated_configs)
})
@@ -412,12 +403,12 @@ def get_info(self) -> Dict[str, Any]:
"""
return {f'option_{i}': cfg.get_info() for i, cfg in enumerate(self.quantization_configurations)}
+
class TargetPlatformModelComponent(BaseModel):
"""
Component of TargetPlatformCapabilities (Fusing, OperatorsSet, etc.).
"""
- class Config:
- frozen = True
+ model_config = ConfigDict(frozen=True)
class OperatorsSetBase(TargetPlatformModelComponent):
@@ -444,8 +435,7 @@ class OperatorsSet(OperatorsSetBase):
# Define a private attribute _type
type: Literal["OperatorsSet"] = "OperatorsSet"
- class Config:
- frozen = True
+ model_config = ConfigDict(frozen=True)
def get_info(self) -> Dict[str, Any]:
"""
@@ -471,10 +461,9 @@ class OperatorSetGroup(OperatorsSetBase):
# Define a private attribute _type
type: Literal["OperatorSetGroup"] = "OperatorSetGroup"
- class Config:
- frozen = True
+ model_config = ConfigDict(frozen=True)
- @root_validator(pre=True, allow_reuse=True)
+ @model_validator(mode="before")
def validate_and_set_name(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""
Validate the input and set the concatenated name based on the operators_set.
@@ -512,6 +501,7 @@ def get_info(self) -> Dict[str, Any]:
"operators_set": [op.get_info() for op in self.operators_set]
}
+
class Fusing(TargetPlatformModelComponent):
"""
Fusing defines a tuple of operators that should be combined and treated as a single operator,
@@ -525,10 +515,9 @@ class Fusing(TargetPlatformModelComponent):
operator_groups: Tuple[Annotated[Union[OperatorsSet, OperatorSetGroup], Field(discriminator='type')], ...]
name: Optional[str] = None # Will be set in the validator if not given.
- class Config:
- frozen = True
+ model_config = ConfigDict(frozen=True)
- @root_validator(pre=True, allow_reuse=True)
+ @model_validator(mode="before")
def validate_and_set_name(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""
Validate the operator_groups and set the name by concatenating operator group names.
@@ -555,24 +544,15 @@ def validate_and_set_name(cls, values: Dict[str, Any]) -> Dict[str, Any]:
return values
- @root_validator(allow_reuse=True)
- def validate_after_initialization(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+ @model_validator(mode="after")
+ def validate_after_initialization(cls, model: 'Fusing') -> Any:
"""
Perform validation after the model has been instantiated.
-
- Args:
- values (Dict[str, Any]): The instantiated fusing.
-
- Returns:
- Dict[str, Any]: The validated values.
+ Ensures that there are at least two operator groups.
"""
- operator_groups = values.get('operator_groups')
-
- # Validate that there are at least two operator groups
- if len(operator_groups) < 2:
+ if len(model.operator_groups) < 2:
Logger.critical("Fusing cannot be created for a single operator.") # pragma: no cover
-
- return values
+ return model
def contains(self, other: Any) -> bool:
"""
@@ -621,6 +601,7 @@ def get_info(self) -> Union[Dict[str, str], str]:
for x in self.operator_groups
])
+
class TargetPlatformCapabilities(BaseModel):
"""
Represents the hardware configuration used for quantized model inference.
@@ -638,38 +619,37 @@ class TargetPlatformCapabilities(BaseModel):
SCHEMA_VERSION (int): Version of the schema for the Target Platform Model.
"""
default_qco: QuantizationConfigOptions
- operator_set: Optional[Tuple[OperatorsSet, ...]]
- fusing_patterns: Optional[Tuple[Fusing, ...]]
- tpc_minor_version: Optional[int]
- tpc_patch_version: Optional[int]
- tpc_platform_type: Optional[str]
+ operator_set: Optional[Tuple[OperatorsSet, ...]] = None
+ fusing_patterns: Optional[Tuple[Fusing, ...]] = None
+ tpc_minor_version: Optional[int] = None
+ tpc_patch_version: Optional[int] = None
+ tpc_platform_type: Optional[str] = None
add_metadata: bool = True
name: Optional[str] = "default_tpc"
is_simd_padding: bool = False
SCHEMA_VERSION: int = 1
- class Config:
- frozen = True
+ model_config = ConfigDict(frozen=True)
- @root_validator(allow_reuse=True)
- def validate_after_initialization(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+ @model_validator(mode="after")
+ def validate_after_initialization(cls, model: 'TargetPlatformCapabilities') -> Any:
"""
Perform validation after the model has been instantiated.
Args:
- values (Dict[str, Any]): The instantiated target platform model.
+ model (TargetPlatformCapabilities): The instantiated target platform model.
Returns:
- Dict[str, Any]: The validated values.
+ TargetPlatformCapabilities: The validated model.
"""
# Validate `default_qco`
- default_qco = values.get('default_qco')
+ default_qco = model.default_qco
if len(default_qco.quantization_configurations) != 1:
Logger.critical("Default QuantizationConfigOptions must contain exactly one option.") # pragma: no cover
# Validate `operator_set` uniqueness
- operator_set = values.get('operator_set')
+ operator_set = model.operator_set
if operator_set is not None:
opsets_names = [
op.name.value if isinstance(op.name, OperatorSetNames) else op.name
@@ -678,7 +658,7 @@ def validate_after_initialization(cls, values: Dict[str, Any]) -> Dict[str, Any]
if len(set(opsets_names)) != len(opsets_names):
Logger.critical("Operator Sets must have unique names.") # pragma: no cover
- return values
+ return model
def get_info(self) -> Dict[str, Any]:
"""
diff --git a/model_compression_toolkit/target_platform_capabilities/schema/v2.py b/model_compression_toolkit/target_platform_capabilities/schema/v2.py
index 475d4379c..bad8cf6b4 100644
--- a/model_compression_toolkit/target_platform_capabilities/schema/v2.py
+++ b/model_compression_toolkit/target_platform_capabilities/schema/v2.py
@@ -13,10 +13,9 @@
# limitations under the License.
# ==============================================================================
import pprint
-from enum import Enum
from typing import Dict, Any, Tuple, Optional
-from pydantic import BaseModel, root_validator
+from pydantic import BaseModel, root_validator, model_validator, ConfigDict
from mct_quantizers import QuantizationMethod
from model_compression_toolkit.constants import FLOAT_BITWIDTH
@@ -62,27 +61,26 @@ class TargetPlatformCapabilities(BaseModel):
SCHEMA_VERSION: int = 2
- class Config:
- frozen = True
+ model_config = ConfigDict(frozen=True)
- @root_validator(allow_reuse=True)
- def validate_after_initialization(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+ @model_validator(mode="after")
+ def validate_after_initialization(cls, model: 'TargetPlatformCapabilities') -> Any:
"""
Perform validation after the model has been instantiated.
Args:
- values (Dict[str, Any]): The instantiated target platform model.
+ model (TargetPlatformCapabilities): The instantiated target platform model.
Returns:
- Dict[str, Any]: The validated values.
+ TargetPlatformCapabilities: The validated model.
"""
# Validate `default_qco`
- default_qco = values.get('default_qco')
+ default_qco = model.default_qco
if len(default_qco.quantization_configurations) != 1:
Logger.critical("Default QuantizationConfigOptions must contain exactly one option.") # pragma: no cover
# Validate `operator_set` uniqueness
- operator_set = values.get('operator_set')
+ operator_set = model.operator_set
if operator_set is not None:
opsets_names = [
op.name.value if isinstance(op.name, OperatorSetNames) else op.name
@@ -91,7 +89,7 @@ def validate_after_initialization(cls, values: Dict[str, Any]) -> Dict[str, Any]
if len(set(opsets_names)) != len(opsets_names):
Logger.critical("Operator Sets must have unique names.") # pragma: no cover
- return values
+ return model
def get_info(self) -> Dict[str, Any]:
"""
diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_io_handler.py b/model_compression_toolkit/target_platform_capabilities/tpc_io_handler.py
index 653192c0e..d40905627 100644
--- a/model_compression_toolkit/target_platform_capabilities/tpc_io_handler.py
+++ b/model_compression_toolkit/target_platform_capabilities/tpc_io_handler.py
@@ -100,6 +100,6 @@ def export_target_platform_capabilities(model: schema.TargetPlatformCapabilities
# Export the model to JSON and write to the file
with path.open('w', encoding='utf-8') as file:
- file.write(model.json(indent=4))
+ file.write(model.model_dump_json(indent=4))
except OSError as e:
raise OSError(f"Failed to write to file '{export_path}': {e.strerror}") from e
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 1ea53aadf..8f1ee1298 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -10,5 +10,6 @@ matplotlib<3.10.0
scipy
protobuf
mct-quantizers-nightly
-pydantic<2.0
+pydantic>=2.0
sony-custom-layers-dev==0.4.0.dev6
+
diff --git a/tests_pytest/common_tests/unit_tests/core/quantization/node_quantization_config/test_node_weights_quantization_config.py b/tests_pytest/common_tests/unit_tests/core/quantization/node_quantization_config/test_node_weights_quantization_config.py
index 9c6fb660e..13a0cb0a5 100644
--- a/tests_pytest/common_tests/unit_tests/core/quantization/node_quantization_config/test_node_weights_quantization_config.py
+++ b/tests_pytest/common_tests/unit_tests/core/quantization/node_quantization_config/test_node_weights_quantization_config.py
@@ -115,7 +115,7 @@ def test_node_weights_quantization_config_op_cfg_mapping(self):
# Test using the positional attribute as the key rather than POS_ATTR; this mismatch should cause
# NodeWeightsQuantizationConfig to fall back to the default weights attribute configuration instead of
# applying the specific one.
- op_cfg = self._create_node_weights_op_cfg(pos_weight_attr=[positional_weight_attr],
+ op_cfg = self._create_node_weights_op_cfg(pos_weight_attr=[str(positional_weight_attr)],
pos_weight_attr_config=[pos_weight_attr_config],
def_weight_attr_config=def_weight_attr_config)
diff --git a/tests_pytest/common_tests/unit_tests/target_platform_capabilities/test_tpc.py b/tests_pytest/common_tests/unit_tests/target_platform_capabilities/test_tpc.py
index 1d69e29f1..03e026b32 100644
--- a/tests_pytest/common_tests/unit_tests/target_platform_capabilities/test_tpc.py
+++ b/tests_pytest/common_tests/unit_tests/target_platform_capabilities/test_tpc.py
@@ -14,6 +14,7 @@
# ==============================================================================
import os
import pytest
+from pydantic import ValidationError
import model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema as current_schema
from model_compression_toolkit.core.common import BaseNode
from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR
@@ -165,7 +166,7 @@ def test_valid_export(self, tpc, valid_export_path):
assert os.path.exists(str(valid_export_path))
with open(str(valid_export_path), "r", encoding="utf-8") as file:
content = file.read()
- assert content == tpc.json(indent=4)
+ assert content == tpc.model_dump_json(indent=4)
def test_export_with_invalid_model(self, valid_export_path):
"""Tests that exporting an invalid model raises a ValueError."""
@@ -178,13 +179,13 @@ def test_export_with_invalid_path(self, tpc, invalid_export_path):
export_target_platform_capabilities(tpc, str(invalid_export_path))
def test_export_creates_parent_directories(self, tpc, tmp_path):
- """Tests that exporting to an invalid path raises an OSError."""
+ """Tests that exporting creates parent directories as needed."""
nested_path = tmp_path / "nested" / "directory" / "exported_model.json"
export_target_platform_capabilities(tpc, str(nested_path))
assert os.path.exists(str(nested_path))
with open(str(nested_path), "r", encoding="utf-8") as file:
content = file.read()
- assert content == tpc.json(indent=4)
+ assert content == tpc.model_dump_json(indent=4)
# Cleanup created directories
os.remove(str(nested_path))
os.rmdir(str(tmp_path / "nested" / "directory"))
@@ -200,17 +201,19 @@ def test_export_then_import(self, tpc, valid_export_path):
class TestTargetPlatformModeling:
def test_immutable_tp(self):
"""Tests that modifying an immutable TargetPlatformCapabilities instance raises an exception."""
- with pytest.raises(Exception, match='"TargetPlatformCapabilities" is immutable and does not support item assignment'):
- model = current_schema.TargetPlatformCapabilities(
- default_qco=TEST_QCO,
+ model = current_schema.TargetPlatformCapabilities(
+ default_qco=TEST_QCO,
operator_set=(current_schema.OperatorsSet(name="opset"),),
- tpc_minor_version=None,
- tpc_patch_version=None,
- tpc_platform_type=None,
- add_metadata=False
- )
+ tpc_minor_version=None,
+ tpc_patch_version=None,
+ tpc_platform_type=None,
+ add_metadata=False
+ )
+        # Expecting a pydantic ValidationError because the model is frozen (immutable)
+        with pytest.raises(ValidationError, match="Instance is frozen"):
model.operator_set = tuple()
+
def test_default_options_more_than_single_qc(self):
"""Tests that creating a TargetPlatformCapabilities with default_qco containing more than one configuration raises an exception."""
test_qco = current_schema.QuantizationConfigOptions(
@@ -304,7 +307,7 @@ def test_empty_qc_options(self):
def test_list_of_no_qc(self):
"""Tests that providing an invalid configuration list (non-dict values) to QuantizationConfigOptions raises an exception."""
- with pytest.raises(Exception, match="value is not a valid dict"):
+ with pytest.raises(ValidationError, match="Input should be a valid dictionary"):
current_schema.QuantizationConfigOptions(quantization_configurations=(TEST_QC, 3), base_config=TEST_QC)
def test_clone_and_edit_options(self):