1+ # Copyright 2025 Sony Semiconductor Solutions, Inc. All rights reserved.
2+ #
3+ # Licensed under the Apache License, Version 2.0 (the "License");
4+ # you may not use this file except in compliance with the License.
5+ # You may obtain a copy of the License at
6+ #
7+ # http://www.apache.org/licenses/LICENSE-2.0
8+ #
9+ # Unless required by applicable law or agreed to in writing, software
10+ # distributed under the License is distributed on an "AS IS" BASIS,
11+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+ # See the License for the specific language governing permissions and
13+ # limitations under the License.
14+ # ==============================================================================
15+ from typing import Iterator , List
16+ import torch
17+ import torch .nn as nn
18+ import model_compression_toolkit as mct
19+ from model_compression_toolkit .target_platform_capabilities import AttributeQuantizationConfig
20+ import model_compression_toolkit .target_platform_capabilities .schema .mct_current_schema as schema
21+ from mct_quantizers import QuantizationMethod
22+
23+
class TakeModel(nn.Module):
    """Tiny conv network whose output is gathered via ``torch.take``.

    Applies Conv2d(3 -> 16, 3x3) followed by ReLU, then selects elements
    from the flattened activation tensor at the given ``indices``.
    """

    def __init__(self, indices):
        super().__init__()
        # 3 input channels, 16 output channels, 3x3 kernel.
        self.conv = nn.Conv2d(3, 16, kernel_size=3)
        self.relu = nn.ReLU()
        # torch.take requires integer (long) indices into the flattened tensor.
        self.indices = torch.as_tensor(indices, dtype=torch.long)

    def forward(self, x):
        features = self.relu(self.conv(x))
        # torch.take treats `features` as 1-D and returns one value per index.
        return torch.take(features, self.indices)
36+
37+
def get_representative_dataset(n_iter=1):
    """Return a callable that yields ``n_iter`` random calibration batches.

    Each yielded batch is a one-element list holding a random tensor of
    shape (1, 3, 32, 32) — the input shape TakeModel expects.
    """

    def representative_dataset() -> Iterator[List]:
        # One random batch per calibration iteration.
        yield from ([torch.randn(1, 3, 32, 32)] for _ in range(n_iter))

    return representative_dataset
44+
45+
def get_edgemdt_tpc_v6():
    """Build a TargetPlatformCapabilities equivalent to edgemdt-tpc v6.0.

    CONV and RELU use the default 8-bit power-of-two activation config,
    while TAKE is marked quantization-preserving: its own activation
    quantization is disabled, it accepts 8/16-bit inputs, and its weight
    attribute is left unquantized.
    """

    default_config = schema.OpQuantizationConfig(
        default_weight_attr_config=AttributeQuantizationConfig(),
        attr_weights_configs_mapping={},
        activation_quantization_method=QuantizationMethod.POWER_OF_TWO,
        activation_n_bits=8,
        supported_input_activation_n_bits=8,
        enable_activation_quantization=True,
        quantization_preserving=False,
        fixed_scale=None,
        fixed_zero_point=None,
        simd_size=32,
        signedness=schema.Signedness.AUTO)

    default_qco = schema.QuantizationConfigOptions(
        quantization_configurations=(default_config,))

    # TAKE only rearranges data, so it carries its input quantization
    # forward instead of introducing a quantizer of its own.
    preserving_qco = (default_qco
                      .clone_and_edit(enable_activation_quantization=False,
                                      quantization_preserving=True,
                                      supported_input_activation_n_bits=(8, 16))
                      .clone_and_edit_weight_attribute(enable_weights_quantization=False))

    op_sets = (
        schema.OperatorsSet(name=schema.OperatorSetNames.TAKE, qc_options=preserving_qco),
        schema.OperatorsSet(name=schema.OperatorSetNames.CONV, qc_options=default_qco),
        schema.OperatorsSet(name=schema.OperatorSetNames.RELU, qc_options=default_qco),
    )

    return schema.TargetPlatformCapabilities(default_qco=default_qco,
                                             operator_set=op_sets)
75+
76+
def test_take():
    """PTQ over TakeModel: `take` must survive as quantization-preserving.

    After post-training quantization, the `take` op should exist on the
    quantized module but must NOT receive its own activation-holder
    quantizer, since the TPC marks it quantization-preserving.
    """
    model = TakeModel(indices=[0, 100])
    # TPC equivalent to edgemdt-tpc v6.0
    tpc = get_edgemdt_tpc_v6()
    quantized_model, _ = mct.ptq.pytorch_post_training_quantization(
        model,
        get_representative_dataset(n_iter=1),
        target_resource_utilization=None,
        core_config=mct.core.CoreConfig(),
        target_platform_capabilities=tpc)
    # The take op is present on the quantized module...
    assert hasattr(quantized_model, 'take')
    # ...but no activation quantizer holder was attached to it.
    assert not hasattr(quantized_model, 'take_activation_holder_quantizer')