Skip to content

Commit 2a2310d

Browse files
Back to schema version 1 (#1401)
1 parent 7a08405 commit 2a2310d

File tree

5 files changed

+4
-70
lines changed

5 files changed

+4
-70
lines changed

model_compression_toolkit/target_platform_capabilities/schema/mct_current_schema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import model_compression_toolkit.target_platform_capabilities.schema.v2 as schema
1+
import model_compression_toolkit.target_platform_capabilities.schema.v1 as schema
22

33
OperatorSetNames = schema.OperatorSetNames
44
Signedness = schema.Signedness

model_compression_toolkit/target_platform_capabilities/schema/v2.py

Lines changed: 2 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -30,72 +30,8 @@
3030
OperatorsSetBase,
3131
OperatorsSet,
3232
OperatorSetGroup,
33-
Fusing)
34-
35-
36-
class OperatorSetNames(str, Enum):
37-
CONV = "Conv"
38-
DEPTHWISE_CONV = "DepthwiseConv2D"
39-
CONV_TRANSPOSE = "ConvTranspose"
40-
FULLY_CONNECTED = "FullyConnected"
41-
CONCATENATE = "Concatenate"
42-
STACK = "Stack"
43-
UNSTACK = "Unstack"
44-
GATHER = "Gather"
45-
EXPAND = "Expend"
46-
BATCH_NORM = "BatchNorm"
47-
L2NORM = "L2Norm"
48-
RELU = "ReLU"
49-
RELU6 = "ReLU6"
50-
LEAKY_RELU = "LeakyReLU"
51-
ELU = "Elu"
52-
HARD_TANH = "HardTanh"
53-
ADD = "Add"
54-
SUB = "Sub"
55-
MUL = "Mul"
56-
DIV = "Div"
57-
MIN = "Min"
58-
MAX = "Max"
59-
PRELU = "PReLU"
60-
ADD_BIAS = "AddBias"
61-
SWISH = "Swish"
62-
SIGMOID = "Sigmoid"
63-
SOFTMAX = "Softmax"
64-
LOG_SOFTMAX = "LogSoftmax"
65-
TANH = "Tanh"
66-
GELU = "Gelu"
67-
HARDSIGMOID = "HardSigmoid"
68-
HARDSWISH = "HardSwish"
69-
FLATTEN = "Flatten"
70-
GET_ITEM = "GetItem"
71-
RESHAPE = "Reshape"
72-
UNSQUEEZE = "Unsqueeze"
73-
SQUEEZE = "Squeeze"
74-
PERMUTE = "Permute"
75-
TRANSPOSE = "Transpose"
76-
DROPOUT = "Dropout"
77-
SPLIT_CHUNK = "SplitChunk"
78-
MAXPOOL = "MaxPool"
79-
AVGPOOL = "AvgPool"
80-
SIZE = "Size"
81-
SHAPE = "Shape"
82-
EQUAL = "Equal"
83-
ARGMAX = "ArgMax"
84-
TOPK = "TopK"
85-
FAKE_QUANT = "FakeQuant"
86-
COMBINED_NON_MAX_SUPPRESSION = "CombinedNonMaxSuppression"
87-
BOX_DECODE = "BoxDecode"
88-
ZERO_PADDING2D = "ZeroPadding2D"
89-
CAST = "Cast"
90-
RESIZE = "Resize"
91-
PAD = "Pad"
92-
FOLD = "Fold"
93-
STRIDED_SLICE = "StridedSlice"
94-
SSD_POST_PROCESS = "SSDPostProcess"
95-
96-
@classmethod
97-
def get_values(cls):
98-
return [v.value for v in cls]
33+
Fusing,
34+
OperatorSetNames)
9935

10036

10137
class TargetPlatformCapabilities(BaseModel):

model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2keras.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,6 @@ def __init__(self):
9393
OperatorSetNames.TOPK: [tf.nn.top_k],
9494
OperatorSetNames.FAKE_QUANT: [tf.quantization.fake_quant_with_min_max_vars],
9595
OperatorSetNames.COMBINED_NON_MAX_SUPPRESSION: [tf.image.combined_non_max_suppression],
96-
OperatorSetNames.BOX_DECODE: [], # no such operator in keras
9796
OperatorSetNames.ZERO_PADDING2D: [ZeroPadding2D],
9897
OperatorSetNames.CAST: [tf.cast],
9998
OperatorSetNames.STRIDED_SLICE: [tf.strided_slice],

model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,6 @@ def __init__(self):
9898
Eq('p', 2) | Eq('p', None))],
9999
OperatorSetNames.SSD_POST_PROCESS: [], # no such operator in pytorch
100100
OperatorSetNames.COMBINED_NON_MAX_SUPPRESSION: [], # no such operator in pytorch
101-
OperatorSetNames.BOX_DECODE: [] # no such operator in pytorch
102101
}
103102

104103
pytorch_linear_attr_mapping = {KERNEL_ATTR: DefaultDict(default_value=PYTORCH_KERNEL),

tests_pytest/_fw_tests_common_base/base_tpc_attach2fw_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def setup_method(self):
5151

5252
def test_attach2fw_init(self):
5353
# verify built-in opset to operator mapping structure
54-
assert len(self.attach2fw._opset2layer) == 58 # number of built-in operator sets
54+
assert len(self.attach2fw._opset2layer) == 57 # number of built-in operator sets
5555
assert all(opset in self.attach2fw._opset2layer for opset in list(schema.OperatorSetNames))
5656
assert all(isinstance(key, schema.OperatorSetNames) for key in self.attach2fw._opset2layer.keys())
5757
assert all(isinstance(value, list) for value in self.attach2fw._opset2layer.values())

0 commit comments

Comments
 (0)