Skip to content

Commit 516061b

Browse files
yarden-yagil-sonyyarden-sony
andauthored
Schema - V2 (#1396)
* add box decode to op set names --------- Co-authored-by: yarden-sony <yardeny-sony@sony.com>
1 parent 1056e36 commit 516061b

File tree

7 files changed

+184
-6
lines changed

7 files changed

+184
-6
lines changed

model_compression_toolkit/target_platform_capabilities/schema/mct_current_schema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import model_compression_toolkit.target_platform_capabilities.schema.v1 as schema
1+
import model_compression_toolkit.target_platform_capabilities.schema.v2 as schema
22

33
OperatorSetNames = schema.OperatorSetNames
44
Signedness = schema.Signedness
Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
import pprint
16+
from enum import Enum
17+
from typing import Dict, Any, Tuple, Optional
18+
19+
from pydantic import BaseModel, root_validator
20+
21+
from mct_quantizers import QuantizationMethod
22+
from model_compression_toolkit.constants import FLOAT_BITWIDTH
23+
from model_compression_toolkit.logger import Logger
24+
from model_compression_toolkit.target_platform_capabilities.schema.v1 import (
25+
Signedness,
26+
AttributeQuantizationConfig,
27+
OpQuantizationConfig,
28+
QuantizationConfigOptions,
29+
TargetPlatformModelComponent,
30+
OperatorsSetBase,
31+
OperatorsSet,
32+
OperatorSetGroup,
33+
Fusing)
34+
35+
36+
class OperatorSetNames(str, Enum):
37+
CONV = "Conv"
38+
DEPTHWISE_CONV = "DepthwiseConv2D"
39+
CONV_TRANSPOSE = "ConvTranspose"
40+
FULLY_CONNECTED = "FullyConnected"
41+
CONCATENATE = "Concatenate"
42+
STACK = "Stack"
43+
UNSTACK = "Unstack"
44+
GATHER = "Gather"
45+
EXPAND = "Expend"
46+
BATCH_NORM = "BatchNorm"
47+
L2NORM = "L2Norm"
48+
RELU = "ReLU"
49+
RELU6 = "ReLU6"
50+
LEAKY_RELU = "LeakyReLU"
51+
ELU = "Elu"
52+
HARD_TANH = "HardTanh"
53+
ADD = "Add"
54+
SUB = "Sub"
55+
MUL = "Mul"
56+
DIV = "Div"
57+
MIN = "Min"
58+
MAX = "Max"
59+
PRELU = "PReLU"
60+
ADD_BIAS = "AddBias"
61+
SWISH = "Swish"
62+
SIGMOID = "Sigmoid"
63+
SOFTMAX = "Softmax"
64+
LOG_SOFTMAX = "LogSoftmax"
65+
TANH = "Tanh"
66+
GELU = "Gelu"
67+
HARDSIGMOID = "HardSigmoid"
68+
HARDSWISH = "HardSwish"
69+
FLATTEN = "Flatten"
70+
GET_ITEM = "GetItem"
71+
RESHAPE = "Reshape"
72+
UNSQUEEZE = "Unsqueeze"
73+
SQUEEZE = "Squeeze"
74+
PERMUTE = "Permute"
75+
TRANSPOSE = "Transpose"
76+
DROPOUT = "Dropout"
77+
SPLIT_CHUNK = "SplitChunk"
78+
MAXPOOL = "MaxPool"
79+
AVGPOOL = "AvgPool"
80+
SIZE = "Size"
81+
SHAPE = "Shape"
82+
EQUAL = "Equal"
83+
ARGMAX = "ArgMax"
84+
TOPK = "TopK"
85+
FAKE_QUANT = "FakeQuant"
86+
COMBINED_NON_MAX_SUPPRESSION = "CombinedNonMaxSuppression"
87+
BOX_DECODE = "BoxDecode"
88+
ZERO_PADDING2D = "ZeroPadding2D"
89+
CAST = "Cast"
90+
RESIZE = "Resize"
91+
PAD = "Pad"
92+
FOLD = "Fold"
93+
STRIDED_SLICE = "StridedSlice"
94+
SSD_POST_PROCESS = "SSDPostProcess"
95+
96+
@classmethod
97+
def get_values(cls):
98+
return [v.value for v in cls]
99+
100+
101+
class TargetPlatformCapabilities(BaseModel):
102+
"""
103+
Represents the hardware configuration used for quantized model inference.
104+
105+
Attributes:
106+
default_qco (QuantizationConfigOptions): Default quantization configuration options for the model.
107+
operator_set (Optional[Tuple[OperatorsSet, ...]]): Tuple of operator sets within the model.
108+
fusing_patterns (Optional[Tuple[Fusing, ...]]): Tuple of fusing patterns for the model.
109+
tpc_minor_version (Optional[int]): Minor version of the Target Platform Configuration.
110+
tpc_patch_version (Optional[int]): Patch version of the Target Platform Configuration.
111+
tpc_platform_type (Optional[str]): Type of the platform for the Target Platform Configuration.
112+
add_metadata (bool): Flag to determine if metadata should be added.
113+
name (str): Name of the Target Platform Model.
114+
is_simd_padding (bool): Indicates if SIMD padding is applied.
115+
SCHEMA_VERSION (int): Version of the schema for the Target Platform Model.
116+
"""
117+
default_qco: QuantizationConfigOptions
118+
operator_set: Optional[Tuple[OperatorsSet, ...]]
119+
fusing_patterns: Optional[Tuple[Fusing, ...]]
120+
tpc_minor_version: Optional[int]
121+
tpc_patch_version: Optional[int]
122+
tpc_platform_type: Optional[str]
123+
add_metadata: bool = True
124+
name: Optional[str] = "default_tpc"
125+
is_simd_padding: bool = False
126+
127+
SCHEMA_VERSION: int = 2
128+
129+
class Config:
130+
frozen = True
131+
132+
@root_validator(allow_reuse=True)
133+
def validate_after_initialization(cls, values: Dict[str, Any]) -> Dict[str, Any]:
134+
"""
135+
Perform validation after the model has been instantiated.
136+
137+
Args:
138+
values (Dict[str, Any]): The instantiated target platform model.
139+
140+
Returns:
141+
Dict[str, Any]: The validated values.
142+
"""
143+
# Validate `default_qco`
144+
default_qco = values.get('default_qco')
145+
if len(default_qco.quantization_configurations) != 1:
146+
Logger.critical("Default QuantizationConfigOptions must contain exactly one option.") # pragma: no cover
147+
148+
# Validate `operator_set` uniqueness
149+
operator_set = values.get('operator_set')
150+
if operator_set is not None:
151+
opsets_names = [
152+
op.name.value if isinstance(op.name, OperatorSetNames) else op.name
153+
for op in operator_set
154+
]
155+
if len(set(opsets_names)) != len(opsets_names):
156+
Logger.critical("Operator Sets must have unique names.") # pragma: no cover
157+
158+
return values
159+
160+
def get_info(self) -> Dict[str, Any]:
161+
"""
162+
Get a dictionary summarizing the TargetPlatformCapabilities properties.
163+
164+
Returns:
165+
Dict[str, Any]: Summary of the TargetPlatformCapabilities properties.
166+
"""
167+
return {
168+
"Model name": self.name,
169+
"Operators sets": [o.get_info() for o in self.operator_set] if self.operator_set else [],
170+
"Fusing patterns": [f.get_info() for f in self.fusing_patterns] if self.fusing_patterns else [],
171+
}
172+
173+
def show(self):
174+
"""
175+
Display the TargetPlatformCapabilities.
176+
"""
177+
pprint.pprint(self.get_info(), sort_dicts=False)

model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2keras.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ def __init__(self):
9393
OperatorSetNames.TOPK: [tf.nn.top_k],
9494
OperatorSetNames.FAKE_QUANT: [tf.quantization.fake_quant_with_min_max_vars],
9595
OperatorSetNames.COMBINED_NON_MAX_SUPPRESSION: [tf.image.combined_non_max_suppression],
96+
OperatorSetNames.BOX_DECODE: [], # no such operator in keras
9697
OperatorSetNames.ZERO_PADDING2D: [ZeroPadding2D],
9798
OperatorSetNames.CAST: [tf.cast],
9899
OperatorSetNames.STRIDED_SLICE: [tf.strided_slice],

model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,8 @@ def __init__(self):
9797
OperatorSetNames.L2NORM: [LayerFilterParams(torch.nn.functional.normalize,
9898
Eq('p', 2) | Eq('p', None))],
9999
OperatorSetNames.SSD_POST_PROCESS: [], # no such operator in pytorch
100-
OperatorSetNames.COMBINED_NON_MAX_SUPPRESSION: [] # no such operator in pytorch
100+
OperatorSetNames.COMBINED_NON_MAX_SUPPRESSION: [], # no such operator in pytorch
101+
OperatorSetNames.BOX_DECODE: [] # no such operator in pytorch
101102
}
102103

103104
pytorch_linear_attr_mapping = {KERNEL_ATTR: DefaultDict(default_value=PYTORCH_KERNEL),

tests/common_tests/helpers/tpcs_for_tests/v4/tpc.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,7 @@
1414
# ==============================================================================
1515
from typing import List, Tuple
1616

17-
import model_compression_toolkit as mct
18-
import model_compression_toolkit.target_platform_capabilities.schema.v1 as schema
17+
import model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema as schema
1918
from mct_quantizers import QuantizationMethod
2019
from model_compression_toolkit.constants import FLOAT_BITWIDTH
2120
from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, BIAS_ATTR, WEIGHTS_N_BITS, \

tests_pytest/_fw_tests_common_base/base_tpc_attach2fw_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def setup_method(self):
5151

5252
def test_attach2fw_init(self):
5353
# verify built-in opset to operator mapping structure
54-
assert len(self.attach2fw._opset2layer) == 57 # number of built-in operator sets
54+
assert len(self.attach2fw._opset2layer) == 58 # number of built-in operator sets
5555
assert all(opset in self.attach2fw._opset2layer for opset in list(schema.OperatorSetNames))
5656
assert all(isinstance(key, schema.OperatorSetNames) for key in self.attach2fw._opset2layer.keys())
5757
assert all(isinstance(value, list) for value in self.attach2fw._opset2layer.values())

tests_pytest/common_tests/unit_tests/core/quantization/quantization_params_selection/test_qarams_activations_computation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
get_activations_qparams
3434
)
3535
from model_compression_toolkit.target_platform_capabilities import Signedness, OpQuantizationConfig
36-
from model_compression_toolkit.target_platform_capabilities.schema.v1 import AttributeQuantizationConfig
36+
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import AttributeQuantizationConfig
3737

3838

3939
class TestActivationQParams:

0 commit comments

Comments
 (0)