1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414# ==============================================================================
15- from typing import Any
15+ from typing import Any , Union
1616
1717import model_compression_toolkit .target_platform_capabilities .schema .v1 as schema_v1
1818import model_compression_toolkit .target_platform_capabilities .schema .v2 as schema_v2
1919import model_compression_toolkit .target_platform_capabilities .schema .mct_current_schema as schema
2020
21- ALL_SCHEMA_VERSIONS = [schema_v1 ]
21+ ALL_SCHEMA_VERSIONS = [schema_v1 ] # needs to be updated with all active schema versions
22+ all_tpc_types = tuple ([s .TargetPlatformCapabilities for s in ALL_SCHEMA_VERSIONS ])
2223
2324
2425def is_tpc_instance (tpc_obj_or_path : Any ) -> bool :
@@ -27,11 +28,10 @@ def is_tpc_instance(tpc_obj_or_path: Any) -> bool:
2728 :param tpc_obj_or_path: Object to check its type
2829 :return: True if the given object is an instance of a TargetPlatformCapabilities, False otherwise
2930 """
30- all_tpc_types = [s .TargetPlatformCapabilities for s in ALL_SCHEMA_VERSIONS ]
3131 return type (tpc_obj_or_path ) in all_tpc_types
3232
3333
34- def _schema_v1_to_v2 (tpc : schema_v1 .TargetPlatformCapabilities ) -> schema .TargetPlatformCapabilities :
34+ def _schema_v1_to_v2 (tpc : schema_v1 .TargetPlatformCapabilities ) -> schema_v2 .TargetPlatformCapabilities :
3535 """
3636 Converts given tpc of schema version 1 to schema version 2
3737 :return: TargetPlatformCapabilities instance of of schema version 2
@@ -46,13 +46,13 @@ def _schema_v1_to_v2(tpc: schema_v1.TargetPlatformCapabilities) -> schema.Target
4646 add_metadata = tpc .add_metadata )
4747
4848
49- def tpc_to_current_schema_version (tpc : schema .TargetPlatformCapabilities ) :
49+ def tpc_to_current_schema_version (tpc : Union [ all_tpc_types ]) -> schema .TargetPlatformCapabilities :
5050 """
5151 Given tpc instance of some schema version, convert it to the current MCT schema version.
5252
5353 In case a new schema is added to MCT, need to add a conversion function from the previous version to the new
5454 version, e.g. if the current schema version was updated from v4 to v5, need to add _schema_v4_to_v5 function to
55- this file, than and add it to the conversion_map.
55+ this file, and add it to the conversion_map.
5656
5757 :param tpc: TargetPlatformCapabilities of some schema version
5858 :return: TargetPlatformCapabilities with the current MCT schema version
0 commit comments