1+ # Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
2+ #
3+ # Licensed under the Apache License, Version 2.0 (the "License");
4+ # you may not use this file except in compliance with the License.
5+ # You may obtain a copy of the License at
6+ #
7+ # http://www.apache.org/licenses/LICENSE-2.0
8+ #
9+ # Unless required by applicable law or agreed to in writing, software
10+ # distributed under the License is distributed on an "AS IS" BASIS,
11+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+ # See the License for the specific language governing permissions and
13+ # limitations under the License.
14+ # ==============================================================================
15+ import pprint
16+ from enum import Enum
17+ from typing import Dict , Any , Tuple , Optional
18+
19+ from pydantic import BaseModel , root_validator
20+
21+ from mct_quantizers import QuantizationMethod
22+ from model_compression_toolkit .constants import FLOAT_BITWIDTH
23+ from model_compression_toolkit .logger import Logger
24+ from model_compression_toolkit .target_platform_capabilities .schema .v1 import (
25+ Signedness ,
26+ AttributeQuantizationConfig ,
27+ OpQuantizationConfig ,
28+ QuantizationConfigOptions ,
29+ TargetPlatformModelComponent ,
30+ OperatorsSetBase ,
31+ OperatorsSet ,
32+ OperatorSetGroup ,
33+ Fusing )
34+
35+
36+ class OperatorSetNames (str , Enum ):
37+ CONV = "Conv"
38+ DEPTHWISE_CONV = "DepthwiseConv2D"
39+ CONV_TRANSPOSE = "ConvTranspose"
40+ FULLY_CONNECTED = "FullyConnected"
41+ CONCATENATE = "Concatenate"
42+ STACK = "Stack"
43+ UNSTACK = "Unstack"
44+ GATHER = "Gather"
45+ EXPAND = "Expend"
46+ BATCH_NORM = "BatchNorm"
47+ L2NORM = "L2Norm"
48+ RELU = "ReLU"
49+ RELU6 = "ReLU6"
50+ LEAKY_RELU = "LeakyReLU"
51+ ELU = "Elu"
52+ HARD_TANH = "HardTanh"
53+ ADD = "Add"
54+ SUB = "Sub"
55+ MUL = "Mul"
56+ DIV = "Div"
57+ MIN = "Min"
58+ MAX = "Max"
59+ PRELU = "PReLU"
60+ ADD_BIAS = "AddBias"
61+ SWISH = "Swish"
62+ SIGMOID = "Sigmoid"
63+ SOFTMAX = "Softmax"
64+ LOG_SOFTMAX = "LogSoftmax"
65+ TANH = "Tanh"
66+ GELU = "Gelu"
67+ HARDSIGMOID = "HardSigmoid"
68+ HARDSWISH = "HardSwish"
69+ FLATTEN = "Flatten"
70+ GET_ITEM = "GetItem"
71+ RESHAPE = "Reshape"
72+ UNSQUEEZE = "Unsqueeze"
73+ SQUEEZE = "Squeeze"
74+ PERMUTE = "Permute"
75+ TRANSPOSE = "Transpose"
76+ DROPOUT = "Dropout"
77+ SPLIT_CHUNK = "SplitChunk"
78+ MAXPOOL = "MaxPool"
79+ AVGPOOL = "AvgPool"
80+ SIZE = "Size"
81+ SHAPE = "Shape"
82+ EQUAL = "Equal"
83+ ARGMAX = "ArgMax"
84+ TOPK = "TopK"
85+ FAKE_QUANT = "FakeQuant"
86+ COMBINED_NON_MAX_SUPPRESSION = "CombinedNonMaxSuppression"
87+ BOX_DECODE = "BoxDecode"
88+ ZERO_PADDING2D = "ZeroPadding2D"
89+ CAST = "Cast"
90+ RESIZE = "Resize"
91+ PAD = "Pad"
92+ FOLD = "Fold"
93+ STRIDED_SLICE = "StridedSlice"
94+ SSD_POST_PROCESS = "SSDPostProcess"
95+
96+ @classmethod
97+ def get_values (cls ):
98+ return [v .value for v in cls ]
99+
100+
101+ class TargetPlatformCapabilities (BaseModel ):
102+ """
103+ Represents the hardware configuration used for quantized model inference.
104+
105+ Attributes:
106+ default_qco (QuantizationConfigOptions): Default quantization configuration options for the model.
107+ operator_set (Optional[Tuple[OperatorsSet, ...]]): Tuple of operator sets within the model.
108+ fusing_patterns (Optional[Tuple[Fusing, ...]]): Tuple of fusing patterns for the model.
109+ tpc_minor_version (Optional[int]): Minor version of the Target Platform Configuration.
110+ tpc_patch_version (Optional[int]): Patch version of the Target Platform Configuration.
111+ tpc_platform_type (Optional[str]): Type of the platform for the Target Platform Configuration.
112+ add_metadata (bool): Flag to determine if metadata should be added.
113+ name (str): Name of the Target Platform Model.
114+ is_simd_padding (bool): Indicates if SIMD padding is applied.
115+ SCHEMA_VERSION (int): Version of the schema for the Target Platform Model.
116+ """
117+ default_qco : QuantizationConfigOptions
118+ operator_set : Optional [Tuple [OperatorsSet , ...]]
119+ fusing_patterns : Optional [Tuple [Fusing , ...]]
120+ tpc_minor_version : Optional [int ]
121+ tpc_patch_version : Optional [int ]
122+ tpc_platform_type : Optional [str ]
123+ add_metadata : bool = True
124+ name : Optional [str ] = "default_tpc"
125+ is_simd_padding : bool = False
126+
127+ SCHEMA_VERSION : int = 2
128+
129+ class Config :
130+ frozen = True
131+
132+ @root_validator (allow_reuse = True )
133+ def validate_after_initialization (cls , values : Dict [str , Any ]) -> Dict [str , Any ]:
134+ """
135+ Perform validation after the model has been instantiated.
136+
137+ Args:
138+ values (Dict[str, Any]): The instantiated target platform model.
139+
140+ Returns:
141+ Dict[str, Any]: The validated values.
142+ """
143+ # Validate `default_qco`
144+ default_qco = values .get ('default_qco' )
145+ if len (default_qco .quantization_configurations ) != 1 :
146+ Logger .critical ("Default QuantizationConfigOptions must contain exactly one option." ) # pragma: no cover
147+
148+ # Validate `operator_set` uniqueness
149+ operator_set = values .get ('operator_set' )
150+ if operator_set is not None :
151+ opsets_names = [
152+ op .name .value if isinstance (op .name , OperatorSetNames ) else op .name
153+ for op in operator_set
154+ ]
155+ if len (set (opsets_names )) != len (opsets_names ):
156+ Logger .critical ("Operator Sets must have unique names." ) # pragma: no cover
157+
158+ return values
159+
160+ def get_info (self ) -> Dict [str , Any ]:
161+ """
162+ Get a dictionary summarizing the TargetPlatformCapabilities properties.
163+
164+ Returns:
165+ Dict[str, Any]: Summary of the TargetPlatformCapabilities properties.
166+ """
167+ return {
168+ "Model name" : self .name ,
169+ "Operators sets" : [o .get_info () for o in self .operator_set ] if self .operator_set else [],
170+ "Fusing patterns" : [f .get_info () for f in self .fusing_patterns ] if self .fusing_patterns else [],
171+ }
172+
173+ def show (self ):
174+ """
175+ Display the TargetPlatformCapabilities.
176+ """
177+ pprint .pprint (self .get_info (), sort_dicts = False )
0 commit comments