Skip to content

Commit 874ccb5

Browse files
Add new TPCv6.0 (torch.take) (#1570)
1 parent 55d1c4b commit 874ccb5

File tree

5 files changed

+92
-2
lines changed

5 files changed

+92
-2
lines changed

model_compression_toolkit/target_platform_capabilities/schema/v2.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ class OperatorSetNames(str, Enum):
9393
EXP = "Exp"
9494
SIN = "Sin"
9595
COS = "Cos"
96+
TAKE = "Take"
9697

9798
@classmethod
9899
def get_values(cls):

model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2keras.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,8 @@ def __init__(self):
103103
OperatorSetNames.SSD_POST_PROCESS: [SSDPostProcess],
104104
OperatorSetNames.EXP: [tf.math.exp],
105105
OperatorSetNames.SIN: [tf.math.sin],
106-
OperatorSetNames.COS: [tf.math.cos]
106+
OperatorSetNames.COS: [tf.math.cos],
107+
OperatorSetNames.TAKE: [], # no such operator in tensorflow
107108
}
108109

109110
self._opset2attr_mapping = {

model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ def __init__(self):
102102
OperatorSetNames.EXP: [torch.exp],
103103
OperatorSetNames.SIN: [torch.sin],
104104
OperatorSetNames.COS: [torch.cos],
105+
OperatorSetNames.TAKE: [torch.take],
105106
}
106107

107108
pytorch_linear_attr_mapping = {KERNEL_ATTR: DefaultDict(default_value=PYTORCH_KERNEL),

tests_pytest/_fw_tests_common_base/base_tpc_attach2fw_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def setup_method(self):
5151

5252
def test_attach2fw_init(self):
5353
# verify built-in opset to operator mapping structure
54-
assert len(self.attach2fw._opset2layer) == 60 # number of built-in operator sets
54+
assert len(self.attach2fw._opset2layer) == 61 # number of built-in operator sets
5555
assert all(opset in self.attach2fw._opset2layer for opset in list(schema.OperatorSetNames))
5656
assert all(isinstance(key, schema.OperatorSetNames) for key in self.attach2fw._opset2layer.keys())
5757
assert all(isinstance(value, list) for value in self.attach2fw._opset2layer.values())
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
# Copyright 2025 Sony Semiconductor Solutions, Inc. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
from typing import Iterator, List
16+
import torch
17+
import torch.nn as nn
18+
import model_compression_toolkit as mct
19+
from model_compression_toolkit.target_platform_capabilities import AttributeQuantizationConfig
20+
import model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema as schema
21+
from mct_quantizers import QuantizationMethod
22+
23+
24+
class TakeModel(nn.Module):
25+
26+
def __init__(self, indices):
27+
super().__init__()
28+
self.conv = nn.Conv2d(3, 16, kernel_size=3)
29+
self.relu = nn.ReLU()
30+
self.indices = torch.as_tensor(indices, dtype=torch.long)
31+
32+
def forward(self, x):
33+
x = self.relu(self.conv(x))
34+
output = torch.take(x, self.indices)
35+
return output
36+
37+
38+
def get_representative_dataset(n_iter=1):
39+
40+
def representative_dataset() -> Iterator[List]:
41+
for _ in range(n_iter):
42+
yield [torch.randn(1, 3, 32, 32)]
43+
return representative_dataset
44+
45+
46+
def get_edgemdt_tpc_v6():
47+
48+
default_config = schema.OpQuantizationConfig(
49+
default_weight_attr_config=AttributeQuantizationConfig(),
50+
attr_weights_configs_mapping={},
51+
activation_quantization_method=QuantizationMethod.POWER_OF_TWO,
52+
activation_n_bits=8,
53+
supported_input_activation_n_bits=8,
54+
enable_activation_quantization=True,
55+
quantization_preserving=False,
56+
fixed_scale=None,
57+
fixed_zero_point=None,
58+
simd_size=32,
59+
signedness=schema.Signedness.AUTO)
60+
61+
default_configuration_options = schema.QuantizationConfigOptions(quantization_configurations=tuple([default_config]))
62+
dim_manipulation_config = (default_configuration_options.clone_and_edit(enable_activation_quantization=False,
63+
quantization_preserving=True,
64+
supported_input_activation_n_bits=(8, 16))
65+
.clone_and_edit_weight_attribute(enable_weights_quantization=False))
66+
operator_set = []
67+
operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.TAKE, qc_options=dim_manipulation_config))
68+
operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.CONV, qc_options=default_configuration_options))
69+
operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.RELU, qc_options=default_configuration_options))
70+
71+
tpc = schema.TargetPlatformCapabilities(
72+
default_qco=default_configuration_options,
73+
operator_set=tuple(operator_set))
74+
return tpc
75+
76+
77+
def test_take():
78+
79+
model = TakeModel(indices=[0, 100])
80+
tpc = get_edgemdt_tpc_v6() # TPC equivalent to edgemdt-tpc v6.0
81+
quantized_model, _ = mct.ptq.pytorch_post_training_quantization(model,
82+
get_representative_dataset(n_iter=1),
83+
target_resource_utilization=None,
84+
core_config=mct.core.CoreConfig(),
85+
target_platform_capabilities=tpc)
86+
assert hasattr(quantized_model, 'take')
87+
assert not hasattr(quantized_model, 'take_activation_holder_quantizer')

0 commit comments

Comments
 (0)