Skip to content

Commit 55ee9c8

Browse files
author
yarden-sony
committed
remove ClassVar
1 parent 663d8fc commit 55ee9c8

File tree

5 files changed

+25
-16
lines changed

5 files changed

+25
-16
lines changed

model_compression_toolkit/target_platform_capabilities/schema/schema_compatability.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,18 @@ def _schema_v1_to_v2(tpc: schema_v1.TargetPlatformCapabilities) -> schema_v2.Tar
4646
tpc_platform_type=tpc.tpc_platform_type,
4747
add_metadata=tpc.add_metadata)
4848

49+
def get_conversion_map() -> dict:
50+
"""
51+
Retrieves the schema conversion map.
52+
:return: A dictionary where:
53+
- Keys representing supported source schema versions.
54+
- Values: Callable functions that take tpc in one schema version and return it in the next (higher) version
55+
"""
56+
conversion_map = {
57+
1: _schema_v1_to_v2,
58+
}
59+
return conversion_map
60+
4961

5062
def tpc_to_current_schema_version(tpc: Union[all_tpc_types]) -> schema.TargetPlatformCapabilities:
5163
"""
@@ -58,10 +70,8 @@ def tpc_to_current_schema_version(tpc: Union[all_tpc_types]) -> schema.TargetPla
5870
:param tpc: TargetPlatformCapabilities of some schema version
5971
:return: TargetPlatformCapabilities with the current MCT schema version
6072
"""
61-
conversion_map = {
62-
schema_v1.TargetPlatformCapabilities.SCHEMA_VERSION: _schema_v1_to_v2,
63-
}
64-
while tpc.SCHEMA_VERSION < schema.TargetPlatformCapabilities.SCHEMA_VERSION:
73+
conversion_map = get_conversion_map()
74+
while not isinstance(tpc, schema.TargetPlatformCapabilities):
6575
if tpc.SCHEMA_VERSION not in conversion_map:
6676
raise KeyError(f"TPC using schema version {tpc.SCHEMA_VERSION} which is not in schemas conversion map. "
6777
f"Make sure the schema version is supported, or add it in case it's a new schema version")

model_compression_toolkit/target_platform_capabilities/schema/v1.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# ==============================================================================
1515
import pprint
1616
from enum import Enum
17-
from typing import Dict, Any, Union, Tuple, List, Optional, Literal, Annotated, ClassVar
17+
from typing import Dict, Any, Union, Tuple, List, Optional, Literal, Annotated
1818

1919
from pydantic import BaseModel, Field, root_validator, validator, PositiveInt
2020

@@ -647,7 +647,7 @@ class TargetPlatformCapabilities(BaseModel):
647647
name: Optional[str] = "default_tpc"
648648
is_simd_padding: bool = False
649649

650-
SCHEMA_VERSION: ClassVar[int] = 1
650+
SCHEMA_VERSION: int = 1
651651

652652
class Config:
653653
frozen = True

model_compression_toolkit/target_platform_capabilities/schema/v2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# ==============================================================================
1515
import pprint
1616
from enum import Enum
17-
from typing import Dict, Any, Tuple, Optional, ClassVar
17+
from typing import Dict, Any, Tuple, Optional
1818

1919
from pydantic import BaseModel, root_validator
2020

@@ -60,7 +60,7 @@ class TargetPlatformCapabilities(BaseModel):
6060
name: Optional[str] = "default_tpc"
6161
is_simd_padding: bool = False
6262

63-
SCHEMA_VERSION: ClassVar[int] = 2
63+
SCHEMA_VERSION: int = 2
6464

6565
class Config:
6666
frozen = True

model_compression_toolkit/target_platform_capabilities/tpc_io_handler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def load_target_platform_capabilities(tpc_obj_or_path: Union[tpc_or_str_type]) -
6868
f"but received type '{type(tpc_obj_or_path).__name__}'."
6969
)
7070

71-
if tpc.SCHEMA_VERSION == schema.TargetPlatformCapabilities.SCHEMA_VERSION:
71+
if isinstance(tpc.SCHEMA_VERSION, schema.TargetPlatformCapabilities): # if tpc is of current schema version
7272
return tpc
7373
return tpc_to_current_schema_version(tpc)
7474

tests_pytest/common_tests/unit_tests/target_platform_capabilities/test_tpc.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717
import model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema as schema
1818
from model_compression_toolkit.core.common import BaseNode
1919
from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR
20-
from model_compression_toolkit.target_platform_capabilities.schema.schema_compatability import ALL_SCHEMA_VERSIONS
20+
from model_compression_toolkit.target_platform_capabilities.schema.schema_compatability import ALL_SCHEMA_VERSIONS, \
21+
all_tpc_types, tpc_to_current_schema_version
2122
from model_compression_toolkit.target_platform_capabilities.schema.schema_functions import get_config_options_by_operators_set, is_opset_in_model
2223
from model_compression_toolkit.target_platform_capabilities.tpc_io_handler import load_target_platform_capabilities, \
2324
export_target_platform_capabilities
@@ -107,17 +108,15 @@ def test_valid_model_object(self, tpc):
107108
result = load_target_platform_capabilities(tpc)
108109
assert result == tpc
109110

110-
def test_new_schema(self):
111+
def test_new_schema(self, tpc):
111112
"""Tests that current schema is in all schemas list. This test validates new schema was added properly."""
112-
current_version = schema.TargetPlatformCapabilities.SCHEMA_VERSION
113-
all_supported_versions = [s.TargetPlatformCapabilities.SCHEMA_VERSION for s in ALL_SCHEMA_VERSIONS]
114-
assert current_version in all_supported_versions, "Current schema need to be added to ALL_SCHEMA_VERSIONS"
113+
assert type(tpc) in all_tpc_types, "Current schema need to be added to ALL_SCHEMA_VERSIONS"
115114

116115
def test_schema_compatibility(self, tpc_by_schema_version):
117116
"""Tests that tpc of any schema version is supported and can be converted into current schema version"""
118117
tpc_by_schema = tpc_by_schema_version
119-
result = load_target_platform_capabilities(tpc_by_schema)
120-
assert result.SCHEMA_VERSION == schema.TargetPlatformCapabilities.SCHEMA_VERSION, \
118+
result = tpc_to_current_schema_version(tpc_by_schema)
119+
assert isinstance(result, schema.TargetPlatformCapabilities), \
121120
f"Make sure schema version {result.SCHEMA_VERSION} can be converted into current schema version {result.SCHEMA_VERSION}"
122121

123122
def test_invalid_json_parsing(self, tmp_invalid_json):

0 commit comments

Comments
 (0)