Skip to content

Commit 5f7db25

Browse files
author
yarden-sony
committed
update type hints
1 parent b919c3c commit 5f7db25

File tree

3 files changed

+9
-10
lines changed

3 files changed

+9
-10
lines changed

model_compression_toolkit/target_platform_capabilities/schema/schema_compatability.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,14 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
# ==============================================================================
15-
from typing import Any
15+
from typing import Any, Union
1616

1717
import model_compression_toolkit.target_platform_capabilities.schema.v1 as schema_v1
1818
import model_compression_toolkit.target_platform_capabilities.schema.v2 as schema_v2
1919
import model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema as schema
2020

21-
ALL_SCHEMA_VERSIONS = [schema_v1]
21+
ALL_SCHEMA_VERSIONS = [schema_v1] # needs to be updated with all active schema versions
22+
all_tpc_types = tuple([s.TargetPlatformCapabilities for s in ALL_SCHEMA_VERSIONS])
2223

2324

2425
def is_tpc_instance(tpc_obj_or_path: Any) -> bool:
@@ -27,11 +28,10 @@ def is_tpc_instance(tpc_obj_or_path: Any) -> bool:
2728
:param tpc_obj_or_path: Object to check its type
2829
:return: True if the given object is an instance of a TargetPlatformCapabilities, False otherwise
2930
"""
30-
all_tpc_types = [s.TargetPlatformCapabilities for s in ALL_SCHEMA_VERSIONS]
3131
return type(tpc_obj_or_path) in all_tpc_types
3232

3333

34-
def _schema_v1_to_v2(tpc: schema_v1.TargetPlatformCapabilities) -> schema.TargetPlatformCapabilities:
34+
def _schema_v1_to_v2(tpc: schema_v1.TargetPlatformCapabilities) -> schema_v2.TargetPlatformCapabilities:
3535
"""
3636
Converts given tpc of schema version 1 to schema version 2
3737
:return: TargetPlatformCapabilities instance of of schema version 2
@@ -46,13 +46,13 @@ def _schema_v1_to_v2(tpc: schema_v1.TargetPlatformCapabilities) -> schema.Target
4646
add_metadata=tpc.add_metadata)
4747

4848

49-
def tpc_to_current_schema_version(tpc: schema.TargetPlatformCapabilities):
49+
def tpc_to_current_schema_version(tpc: Union[all_tpc_types]) -> schema.TargetPlatformCapabilities:
5050
"""
5151
Given tpc instance of some schema version, convert it to the current MCT schema version.
5252
5353
In case a new schema is added to MCT, need to add a conversion function from the previous version to the new
5454
version, e.g. if the current schema version was updated from v4 to v5, need to add _schema_v4_to_v5 function to
55-
this file, than and add it to the conversion_map.
55+
this file, and add it to the conversion_map.
5656
5757
:param tpc: TargetPlatformCapabilities of some schema version
5858
:return: TargetPlatformCapabilities with the current MCT schema version

model_compression_toolkit/target_platform_capabilities/schema/v1.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -697,5 +697,4 @@ def show(self):
697697
"""
698698
Display the TargetPlatformCapabilities.
699699
"""
700-
pprint.pprint(self.get_info(), sort_dicts=False)
701-
700+
pprint.pprint(self.get_info(), sort_dicts=False)

model_compression_toolkit/target_platform_capabilities/tpc_io_handler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
import model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema as schema
1919
from model_compression_toolkit.target_platform_capabilities.schema.schema_compatability import is_tpc_instance, \
20-
tpc_to_current_schema_version
20+
tpc_to_current_schema_version, all_tpc_types
2121

2222

2323
def _get_tpc_from_json(tpc_path):
@@ -42,7 +42,7 @@ def _get_tpc_from_json(tpc_path):
4242
raise ValueError(f"Unexpected error while initializing TargetPlatformCapabilities: {e}.") from e
4343

4444

45-
def load_target_platform_capabilities(tpc_obj_or_path: Union[schema.TargetPlatformCapabilities, str]) -> schema.TargetPlatformCapabilities:
45+
def load_target_platform_capabilities(tpc_obj_or_path: Union[all_tpc_types] | str) -> schema.TargetPlatformCapabilities:
4646
"""
4747
Parses the tpc input, which can be either a TargetPlatformCapabilities object
4848
or a string path to a JSON file.

0 commit comments

Comments
 (0)