1313# limitations under the License.
1414# ==============================================================================
1515from functools import partial
16- from typing import List , Any , Tuple , Callable , Dict , Union , Generator
16+ from typing import List , Any , Tuple , Callable , Union , Generator
1717
1818import numpy as np
1919import tensorflow as tf
2222
2323from model_compression_toolkit .constants import HESSIAN_NUM_ITERATIONS
2424from model_compression_toolkit .core .common .graph .functional_node import FunctionalNode
25- from model_compression_toolkit .core .common .hessian import HessianScoresRequest , HessianMode , HessianInfoService
25+ from model_compression_toolkit .core .common .hessian import HessianScoresRequest , HessianMode
2626from model_compression_toolkit .core .keras .data_util import data_gen_to_dataloader
2727from model_compression_toolkit .core .keras .graph_substitutions .substitutions .remove_identity import RemoveIdentity
2828from model_compression_toolkit .core .keras .hessian .activation_hessian_scores_calculator_keras import \
3535from model_compression_toolkit .exporter .model_wrapper .keras .builder .node_to_quantizer import \
3636 get_weights_quantizer_for_node , get_activations_quantizer_for_node
3737from model_compression_toolkit .logger import Logger
38- from model_compression_toolkit .core .common .mixed_precision .sensitivity_evaluation import SensitivityEvaluation
39- from model_compression_toolkit .core .common .mixed_precision .set_layer_to_bitwidth import set_layer_to_bitwidth
4038from model_compression_toolkit .core .common .similarity_analyzer import compute_kl_divergence , compute_cs , compute_mse
4139from model_compression_toolkit .core .keras .constants import ACTIVATION , SOFTMAX , SIGMOID , ARGMAX , LAYER_NAME , \
4240 COMBINED_NMS
6159 from keras .layers import Dense , Activation , Conv2D , DepthwiseConv2D , Conv2DTranspose , Concatenate , Add # pragma: no cover
6260 from keras .layers .core import TFOpLambda # pragma: no cover
6361
64- from model_compression_toolkit .core import QuantizationConfig , FrameworkInfo , CoreConfig , MixedPrecisionQuantizationConfig
62+ from model_compression_toolkit .core import QuantizationConfig , FrameworkInfo , CoreConfig
6563from model_compression_toolkit .core import common
6664from model_compression_toolkit .core .common import Graph , BaseNode
6765from model_compression_toolkit .core .common .framework_implementation import FrameworkImplementation
9593from model_compression_toolkit .core .keras .graph_substitutions .substitutions .scale_equalization import \
9694 ScaleEqualization , ScaleEqualizationWithPad , ScaleEqualizationMidActivation , ScaleEqualizationMidActivationWithPad
9795from model_compression_toolkit .core .keras .graph_substitutions .substitutions .separableconv_decomposition import \
98- SeparableConvDecomposition , DEPTH_MULTIPLIER
96+ SeparableConvDecomposition
9997from model_compression_toolkit .core .keras .graph_substitutions .substitutions .shift_negative_activation import \
10098 keras_apply_shift_negative_correction
10199from model_compression_toolkit .core .keras .graph_substitutions .substitutions .dwconv_to_conv import DwconvToConv
@@ -110,9 +108,10 @@ class KerasImplementation(FrameworkImplementation):
110108 """
111109 A class with implemented methods to support optimizing Keras models.
112110 """
113-
114- def __init__ (self ):
115- super ().__init__ ()
111+ weights_quant_layer_cls = KerasQuantizationWrapper
112+ activation_quant_layer_cls = KerasActivationQuantizationHolder
113+ configurable_weights_quantizer_cls = ConfigurableWeightsQuantizer
114+ configurable_activation_quantizer_cls = ConfigurableActivationQuantizer
116115
117116 @property
118117 def constants (self ):
@@ -401,42 +400,6 @@ def get_substitutions_after_second_moment_correction(self, quant_config: Quantiz
401400 substitutions_list .append (keras_batchnorm_refusing ())
402401 return substitutions_list
403402
404- def get_sensitivity_evaluator (self ,
405- graph : Graph ,
406- quant_config : MixedPrecisionQuantizationConfig ,
407- representative_data_gen : Callable ,
408- fw_info : FrameworkInfo ,
409- disable_activation_for_metric : bool = False ,
410- hessian_info_service : HessianInfoService = None ) -> SensitivityEvaluation :
411- """
412- Creates and returns an object which handles the computation of a sensitivity metric for a mixed-precision
413- configuration (comparing to the float model).
414-
415- Args:
416- graph: Graph to build its float and mixed-precision models.
417- quant_config: QuantizationConfig of how the model should be quantized.
418- representative_data_gen: Dataset to use for retrieving images for the models inputs.
419- fw_info: FrameworkInfo object with information about the specific framework's model.
420- disable_activation_for_metric: Whether to disable activation quantization when computing the MP metric.
421- hessian_info_service: HessianScoresService to fetch scores based on a Hessian-approximation for the float model.
422-
423- Returns:
424- A SensitivityEvaluation object.
425- """
426-
427- return SensitivityEvaluation (graph = graph ,
428- quant_config = quant_config ,
429- representative_data_gen = representative_data_gen ,
430- fw_info = fw_info ,
431- fw_impl = self ,
432- set_layer_to_bitwidth = partial (set_layer_to_bitwidth ,
433- weights_quantizer_type = ConfigurableWeightsQuantizer ,
434- activation_quantizer_type = ConfigurableActivationQuantizer ,
435- weights_quant_layer_type = KerasQuantizationWrapper ,
436- activation_quant_layer_type = KerasActivationQuantizationHolder ),
437- disable_activation_for_metric = disable_activation_for_metric ,
438- hessian_info_service = hessian_info_service )
439-
440403 def get_node_prior_info (self ,
441404 node : BaseNode ,
442405 fw_info : FrameworkInfo ,
0 commit comments