Skip to content

Commit 5c036ad

Browse files
add e2e test
1 parent 4113537 commit 5c036ad

1 file changed

Lines changed: 86 additions & 0 deletions

File tree

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
# Copyright 2025 Sony Semiconductor Solutions, Inc. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
from typing import Iterator, List
16+
import torch
17+
import torch.nn as nn
18+
import model_compression_toolkit as mct
19+
from model_compression_toolkit.target_platform_capabilities import AttributeQuantizationConfig
20+
import model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema as schema
21+
from mct_quantizers import QuantizationMethod
22+
23+
24+
class TakeModel(nn.Module):
25+
26+
def __init__(self, indices):
27+
super().__init__()
28+
self.conv = nn.Conv2d(3, 16, kernel_size=3)
29+
self.relu = nn.ReLU()
30+
self.indices = torch.as_tensor(indices, dtype=torch.long)
31+
32+
def forward(self, x):
33+
x = self.relu(self.conv(x))
34+
output = torch.take(x, self.indices)
35+
return output
36+
37+
38+
def get_representative_dataset(n_iter=1):
39+
40+
def representative_dataset() -> Iterator[List]:
41+
for _ in range(n_iter):
42+
yield [torch.randn(1, 3, 32, 32)]
43+
return representative_dataset
44+
45+
46+
def get_edgemdt_tpc_v6():
47+
48+
default_config = schema.OpQuantizationConfig(
49+
default_weight_attr_config=AttributeQuantizationConfig(),
50+
attr_weights_configs_mapping={},
51+
activation_quantization_method=QuantizationMethod.POWER_OF_TWO,
52+
activation_n_bits=8,
53+
supported_input_activation_n_bits=8,
54+
enable_activation_quantization=True,
55+
quantization_preserving=False,
56+
fixed_scale=None,
57+
fixed_zero_point=None,
58+
simd_size=32,
59+
signedness=schema.Signedness.AUTO)
60+
61+
default_configuration_options = schema.QuantizationConfigOptions(quantization_configurations=tuple([default_config]))
62+
dim_manipulation_config = (default_configuration_options.clone_and_edit(enable_activation_quantization=False,
63+
quantization_preserving=True,
64+
supported_input_activation_n_bits=(8, 16))
65+
.clone_and_edit_weight_attribute(enable_weights_quantization=False))
66+
operator_set = []
67+
operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.TAKE, qc_options=dim_manipulation_config))
68+
operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.CONV, qc_options=default_configuration_options))
69+
operator_set.append(schema.OperatorsSet(name=schema.OperatorSetNames.RELU, qc_options=default_configuration_options))
70+
71+
tpc = schema.TargetPlatformCapabilities(
72+
default_qco=default_configuration_options,
73+
operator_set=tuple(operator_set))
74+
return tpc
75+
76+
77+
def test_take():
78+
79+
model = TakeModel(indices=[0, 100])
80+
tpc = get_edgemdt_tpc_v6() # TPC equivalent to edgemdt-tpc v6.0
81+
quantized_model, _ = mct.ptq.pytorch_post_training_quantization(model,
82+
get_representative_dataset(n_iter=1),
83+
target_resource_utilization=None,
84+
core_config=mct.core.CoreConfig(),
85+
target_platform_capabilities=tpc)
86+
assert hasattr(quantized_model, 'take')

0 commit comments

Comments
 (0)