2323 NodeActivationQuantizationConfig , NodeWeightsQuantizationConfig
2424from model_compression_toolkit .target_platform_capabilities import OpQuantizationConfig
2525from model_compression_toolkit .core import QuantizationConfig , QuantizationErrorMethod
26+ from model_compression_toolkit .core .common .hessian .hessian_info_service import HessianInfoService
2627from model_compression_toolkit .target_platform_capabilities .targetplatform2framework .attach2pytorch import \
2728 AttachTpcToPytorch
2829import model_compression_toolkit .target_platform_capabilities .schema .mct_current_schema as schema
2930from model_compression_toolkit .target_platform_capabilities .schema .mct_current_schema import Signedness , \
3031 AttributeQuantizationConfig
31- from model_compression_toolkit .core .pytorch .default_framework_info import DEFAULT_PYTORCH_INFO
32+ from model_compression_toolkit .core .pytorch .default_framework_info import PyTorchInfo
33+ from model_compression_toolkit .core .common .framework_info import set_fw_info , get_fw_info
34+
3235from model_compression_toolkit .core .pytorch .pytorch_implementation import PytorchImplementation
3336from model_compression_toolkit .core .common .collectors .statistics_collector import StatsCollector
3437from model_compression_toolkit .target_platform_capabilities .constants import KERNEL_ATTR , WEIGHTS_N_BITS
3538from mct_quantizers import QuantizationMethod
3639
40+ from model_compression_toolkit .core .common .framework_info import ChannelAxisMapping
41+
3742class TestCalculateQuantizationParams :
3843 def get_op_qco (self ):
3944 # define a default quantization config for all non-specified weights attributes.
@@ -145,12 +150,11 @@ def _create_node_weights_op_cfg(self,
145150
146151 def get_test_graph (self , qem : QuantizationErrorMethod ):
147152 float_model = self .get_float_model ()
148- fw_info = DEFAULT_PYTORCH_INFO
153+ set_fw_info ( PyTorchInfo )
149154
150155 fw_impl = PytorchImplementation ()
151156 graph = fw_impl .model_reader (float_model ,
152157 self .representative_data_gen )
153- graph .set_fw_info (fw_info )
154158
155159 quantization_config = QuantizationConfig (weights_error_method = qem )
156160
@@ -165,24 +169,23 @@ def get_test_graph(self, qem: QuantizationErrorMethod):
165169
166170 graph .node_to_out_stats_collector = dict ()
167171 for id , n in enumerate (graph .nodes ):
168- n .prior_info = fw_impl .get_node_prior_info (node = n , fw_info = fw_info , graph = graph )
172+ n .prior_info = fw_impl .get_node_prior_info (node = n , graph = graph )
169173 n .candidates_quantization_cfg = []
170174 candidate_qc_a = CandidateNodeQuantizationConfig (
171175 activation_quantization_cfg = NodeActivationQuantizationConfig (qc = quantization_config , op_cfg = op_cfg ,
172176 activation_quantization_fn = None ,
173177 activation_quantization_params_fn = None ),
174178 weights_quantization_cfg = NodeWeightsQuantizationConfig (qc = quantization_config , op_cfg = op_cfg ,
175- weights_channels_axis = (0 , 1 ),
179+ weights_channels_axis = ChannelAxisMapping (0 , 1 ),
176180 node_attrs_list = ['weight' , 'bias' ])
177181 )
178182 if n .name in ['conv3' ]:
179183 candidate_qc_a .activation_quantization_cfg .quant_mode = ActivationQuantizationMode .FLN_QUANT
180- candidate_qc_a .activation_quantization_cfg .activation_n_bits = 16 # set 16bit for FLN node for test.
181184 else :
182185 candidate_qc_a .activation_quantization_cfg .quant_mode = ActivationQuantizationMode .QUANT
183186 n .candidates_quantization_cfg .append (candidate_qc_a )
184187
185- graph .node_to_out_stats_collector [n ] = StatsCollector (init_min_value = 0.0 , init_max_value = 1.0 , out_channel_axis = fw_info .out_channel_axis_mapping .get (n .type ))
188+ graph .node_to_out_stats_collector [n ] = StatsCollector (init_min_value = 0.0 , init_max_value = 1.0 , out_channel_axis = get_fw_info () .out_channel_axis_mapping .get (n .type ))
186189 graph .node_to_out_stats_collector [n ].hc ._n_bins = 3
187190 if n .name in ['conv1' ]:
188191 graph .node_to_out_stats_collector [n ].hc ._bins = np .array ([0.4 , 0.8 , 1.2 ])
0 commit comments