Skip to content

Commit bae7a45

Browse files
irenab
authored and committed
separate configuring activation and weight quantizers
1 parent 646c566 commit bae7a45

File tree

5 files changed

+52
-155
lines changed

5 files changed

+52
-155
lines changed

model_compression_toolkit/core/common/framework_implementation.py

Lines changed: 6 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -13,31 +13,31 @@
1313
# limitations under the License.
1414
# ==============================================================================
1515
from abc import ABC, abstractmethod
16-
from typing import Callable, Any, List, Tuple, Dict, Generator
16+
from typing import Callable, Any, List, Tuple, Generator, Type
1717

1818
import numpy as np
1919

2020
from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS
21-
from model_compression_toolkit.core import MixedPrecisionQuantizationConfig
2221
from model_compression_toolkit.core import common
2322
from model_compression_toolkit.core.common import BaseNode
24-
from model_compression_toolkit.core.common.collectors.statistics_collector import BaseStatsCollector
2523
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
2624
from model_compression_toolkit.core.common.graph.base_graph import Graph
27-
from model_compression_toolkit.core.common.hessian import HessianScoresRequest, HessianInfoService
28-
from model_compression_toolkit.core.common.mixed_precision.sensitivity_evaluation import SensitivityEvaluation
25+
from model_compression_toolkit.core.common.hessian import HessianScoresRequest
2926
from model_compression_toolkit.core.common.model_builder_mode import ModelBuilderMode
3027
from model_compression_toolkit.core.common.node_prior_info import NodePriorInfo
3128
from model_compression_toolkit.core.common.quantization.core_config import CoreConfig
3229
from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
33-
from model_compression_toolkit.core.common.user_info import UserInformation
3430

3531

3632
class FrameworkImplementation(ABC):
3733
"""
3834
An abstract class with abstract methods that should be implemented when supporting a new
3935
framework in MCT.
4036
"""
37+
weights_quant_layer_cls: Type
38+
activation_quant_layer_cls: Type
39+
configurable_weights_quantizer_cls: Type
40+
configurable_activation_quantizer_cls: Type
4141

4242
@property
4343
def constants(self):
@@ -327,33 +327,6 @@ def get_substitutions_after_second_moment_correction(self, quant_config: Quantiz
327327
f'framework\'s get_substitutions_after_second_moment_correction '
328328
f'method.') # pragma: no cover
329329

330-
@abstractmethod
331-
def get_sensitivity_evaluator(self,
332-
graph: Graph,
333-
quant_config: MixedPrecisionQuantizationConfig,
334-
representative_data_gen: Callable,
335-
fw_info: FrameworkInfo,
336-
hessian_info_service: HessianInfoService = None,
337-
disable_activation_for_metric: bool = False) -> SensitivityEvaluation:
338-
"""
339-
Creates and returns an object which handles the computation of a sensitivity metric for a mixed-precision
340-
configuration (comparing to the float model).
341-
342-
Args:
343-
graph: Graph to build its float and mixed-precision models.
344-
quant_config: QuantizationConfig of how the model should be quantized.
345-
representative_data_gen: Dataset to use for retrieving images for the models inputs.
346-
fw_info: FrameworkInfo object with information about the specific framework's model.
347-
disable_activation_for_metric: Whether to disable activation quantization when computing the MP metric.
348-
hessian_info_service: HessianInfoService to fetch information based on Hessian-approximation.
349-
350-
Returns:
351-
A function that computes the metric.
352-
"""
353-
354-
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
355-
f'framework\'s get_sensitivity_evaluator method.') # pragma: no cover
356-
357330
def get_node_prior_info(self, node: BaseNode,
358331
fw_info: FrameworkInfo,
359332
graph: Graph) -> NodePriorInfo:

model_compression_toolkit/core/common/mixed_precision/set_layer_to_bitwidth.py

Lines changed: 32 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -12,39 +12,45 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
# ==============================================================================
15-
from typing import Any
15+
import typing
16+
from typing import Any, Optional
1617

18+
if typing.TYPE_CHECKING:
19+
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
1720

18-
def set_layer_to_bitwidth(quantization_layer: Any,
19-
bitwidth_idx: int,
20-
weights_quantizer_type: type,
21-
activation_quantizer_type: type,
22-
weights_quant_layer_type: type,
23-
activation_quant_layer_type: type):
21+
22+
def set_activation_quant_layer_to_bitwidth(quantization_layer: Any,
23+
bitwidth_idx: Optional[int],
24+
fw_impl: 'FrameworkImplementation'):
2425
"""
25-
Configures a layer's configurable quantizer to work with a different bit-width.
26+
Configures a layer's configurable activation quantizer to work with a different bit-width.
2627
The bitwidth_idx is the index of the actual quantizer that the quantizer object in the quantization_layer wraps/holds.
2728
2829
Args:
2930
quantization_layer: Layer to change its bit-width.
30-
bitwidth_idx: Index of the bit-width the layer should work with.
31-
weights_quantizer_type: A class of weights quantizer with configurable bitwidth options.
32-
activation_quantizer_type: A class of activation quantizer with configurable bitwidth options.
33-
weights_quant_layer_type: A class of a weights layer wrapper.
34-
activation_quant_layer_type: A class of an activation quantization holder.
31+
bitwidth_idx: Index of the bit-width the layer should work with, or None to disable quantization.
32+
fw_impl: framework implementation object.
3533
"""
34+
assert isinstance(quantization_layer, fw_impl.activation_quant_layer_cls)
35+
assert isinstance(quantization_layer.activation_holder_quantizer, fw_impl.configurable_activation_quantizer_cls)
36+
quantization_layer.activation_holder_quantizer.set_active_activation_quantizer(bitwidth_idx)
37+
3638

37-
if isinstance(quantization_layer, weights_quant_layer_type):
38-
for _, quantizer in quantization_layer.weights_quantizers.items():
39-
if isinstance(quantizer, weights_quantizer_type):
40-
# Setting bitwidth only for configurable layers. There might be wrapped layers that aren't configurable,
41-
# for instance, if only activations are quantized with mixed precision and weights are quantized with
42-
# fixed precision
43-
quantizer.set_weights_bit_width_index(bitwidth_idx)
39+
def set_weights_quant_layer_to_bitwidth(quantization_layer: Any,
40+
bitwidth_idx: Optional[int],
41+
fw_impl: 'FrameworkImplementation'):
42+
"""
43+
Configures a layer's configurable weights quantizer to work with a different bit-width.
44+
The bitwidth_idx is the index of the actual quantizer that the quantizer object in the quantization_layer wraps/holds.
4445
45-
if isinstance(quantization_layer, activation_quant_layer_type):
46-
if isinstance(quantization_layer.activation_holder_quantizer, activation_quantizer_type):
47-
# Setting bitwidth only for configurable layers. There might be activation layers that isn't configurable,
48-
# for instance, if only weights are quantized with mixed precision and activation are quantized with
49-
# fixed precision
50-
quantization_layer.activation_holder_quantizer.set_active_activation_quantizer(bitwidth_idx)
46+
Args:
47+
quantization_layer: Layer to change its bit-width.
48+
bitwidth_idx: Index of the bit-width the layer should work with, or None to disable quantization.
49+
fw_impl: framework implementation object.
50+
"""
51+
assert isinstance(quantization_layer, fw_impl.weights_quant_layer_cls)
52+
configurable_quantizers = [q for q in quantization_layer.weights_quantizers.values()
53+
if isinstance(q, fw_impl.configurable_weights_quantizer_cls)]
54+
assert configurable_quantizers
55+
for quantizer in configurable_quantizers:
56+
quantizer.set_weights_bit_width_index(bitwidth_idx)

model_compression_toolkit/core/keras/keras_implementation.py

Lines changed: 8 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414
# ==============================================================================
1515
from functools import partial
16-
from typing import List, Any, Tuple, Callable, Dict, Union, Generator
16+
from typing import List, Any, Tuple, Callable, Union, Generator
1717

1818
import numpy as np
1919
import tensorflow as tf
@@ -22,7 +22,7 @@
2222

2323
from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS
2424
from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
25-
from model_compression_toolkit.core.common.hessian import HessianScoresRequest, HessianMode, HessianInfoService
25+
from model_compression_toolkit.core.common.hessian import HessianScoresRequest, HessianMode
2626
from model_compression_toolkit.core.keras.data_util import data_gen_to_dataloader
2727
from model_compression_toolkit.core.keras.graph_substitutions.substitutions.remove_identity import RemoveIdentity
2828
from model_compression_toolkit.core.keras.hessian.activation_hessian_scores_calculator_keras import \
@@ -35,8 +35,6 @@
3535
from model_compression_toolkit.exporter.model_wrapper.keras.builder.node_to_quantizer import \
3636
get_weights_quantizer_for_node, get_activations_quantizer_for_node
3737
from model_compression_toolkit.logger import Logger
38-
from model_compression_toolkit.core.common.mixed_precision.sensitivity_evaluation import SensitivityEvaluation
39-
from model_compression_toolkit.core.common.mixed_precision.set_layer_to_bitwidth import set_layer_to_bitwidth
4038
from model_compression_toolkit.core.common.similarity_analyzer import compute_kl_divergence, compute_cs, compute_mse
4139
from model_compression_toolkit.core.keras.constants import ACTIVATION, SOFTMAX, SIGMOID, ARGMAX, LAYER_NAME, \
4240
COMBINED_NMS
@@ -61,7 +59,7 @@
6159
from keras.layers import Dense, Activation, Conv2D, DepthwiseConv2D, Conv2DTranspose, Concatenate, Add # pragma: no cover
6260
from keras.layers.core import TFOpLambda # pragma: no cover
6361

64-
from model_compression_toolkit.core import QuantizationConfig, FrameworkInfo, CoreConfig, MixedPrecisionQuantizationConfig
62+
from model_compression_toolkit.core import QuantizationConfig, FrameworkInfo, CoreConfig
6563
from model_compression_toolkit.core import common
6664
from model_compression_toolkit.core.common import Graph, BaseNode
6765
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
@@ -95,7 +93,7 @@
9593
from model_compression_toolkit.core.keras.graph_substitutions.substitutions.scale_equalization import \
9694
ScaleEqualization, ScaleEqualizationWithPad, ScaleEqualizationMidActivation, ScaleEqualizationMidActivationWithPad
9795
from model_compression_toolkit.core.keras.graph_substitutions.substitutions.separableconv_decomposition import \
98-
SeparableConvDecomposition, DEPTH_MULTIPLIER
96+
SeparableConvDecomposition
9997
from model_compression_toolkit.core.keras.graph_substitutions.substitutions.shift_negative_activation import \
10098
keras_apply_shift_negative_correction
10199
from model_compression_toolkit.core.keras.graph_substitutions.substitutions.dwconv_to_conv import DwconvToConv
@@ -110,9 +108,10 @@ class KerasImplementation(FrameworkImplementation):
110108
"""
111109
A class with implemented methods to support optimizing Keras models.
112110
"""
113-
114-
def __init__(self):
115-
super().__init__()
111+
weights_quant_layer_cls = KerasQuantizationWrapper
112+
activation_quant_layer_cls = KerasActivationQuantizationHolder
113+
configurable_weights_quantizer_cls = ConfigurableWeightsQuantizer
114+
configurable_activation_quantizer_cls = ConfigurableActivationQuantizer
116115

117116
@property
118117
def constants(self):
@@ -401,42 +400,6 @@ def get_substitutions_after_second_moment_correction(self, quant_config: Quantiz
401400
substitutions_list.append(keras_batchnorm_refusing())
402401
return substitutions_list
403402

404-
def get_sensitivity_evaluator(self,
405-
graph: Graph,
406-
quant_config: MixedPrecisionQuantizationConfig,
407-
representative_data_gen: Callable,
408-
fw_info: FrameworkInfo,
409-
disable_activation_for_metric: bool = False,
410-
hessian_info_service: HessianInfoService = None) -> SensitivityEvaluation:
411-
"""
412-
Creates and returns an object which handles the computation of a sensitivity metric for a mixed-precision
413-
configuration (comparing to the float model).
414-
415-
Args:
416-
graph: Graph to build its float and mixed-precision models.
417-
quant_config: QuantizationConfig of how the model should be quantized.
418-
representative_data_gen: Dataset to use for retrieving images for the models inputs.
419-
fw_info: FrameworkInfo object with information about the specific framework's model.
420-
disable_activation_for_metric: Whether to disable activation quantization when computing the MP metric.
421-
hessian_info_service: HessianScoresService to fetch scores based on a Hessian-approximation for the float model.
422-
423-
Returns:
424-
A SensitivityEvaluation object.
425-
"""
426-
427-
return SensitivityEvaluation(graph=graph,
428-
quant_config=quant_config,
429-
representative_data_gen=representative_data_gen,
430-
fw_info=fw_info,
431-
fw_impl=self,
432-
set_layer_to_bitwidth=partial(set_layer_to_bitwidth,
433-
weights_quantizer_type=ConfigurableWeightsQuantizer,
434-
activation_quantizer_type=ConfigurableActivationQuantizer,
435-
weights_quant_layer_type=KerasQuantizationWrapper,
436-
activation_quant_layer_type=KerasActivationQuantizationHolder),
437-
disable_activation_for_metric=disable_activation_for_metric,
438-
hessian_info_service=hessian_info_service)
439-
440403
def get_node_prior_info(self,
441404
node: BaseNode,
442405
fw_info: FrameworkInfo,

model_compression_toolkit/core/pytorch/pytorch_implementation.py

Lines changed: 6 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,12 @@
2626

2727
import model_compression_toolkit.core.pytorch.constants as pytorch_constants
2828
from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS
29-
from model_compression_toolkit.core import QuantizationConfig, FrameworkInfo, CoreConfig, MixedPrecisionQuantizationConfig
29+
from model_compression_toolkit.core import QuantizationConfig, FrameworkInfo, CoreConfig
3030
from model_compression_toolkit.core import common
3131
from model_compression_toolkit.core.common import Graph, BaseNode
3232
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
3333
from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
34-
from model_compression_toolkit.core.common.hessian import HessianScoresRequest, HessianMode, HessianInfoService
35-
from model_compression_toolkit.core.common.mixed_precision.sensitivity_evaluation import SensitivityEvaluation
36-
from model_compression_toolkit.core.common.mixed_precision.set_layer_to_bitwidth import set_layer_to_bitwidth
34+
from model_compression_toolkit.core.common.hessian import HessianScoresRequest, HessianMode
3735
from model_compression_toolkit.core.common.model_builder_mode import ModelBuilderMode
3836
from model_compression_toolkit.core.common.node_prior_info import NodePriorInfo
3937
from model_compression_toolkit.core.common.similarity_analyzer import compute_mse, compute_kl_divergence, compute_cs
@@ -112,6 +110,10 @@ class PytorchImplementation(FrameworkImplementation):
112110
"""
113111
A class with implemented methods to support optimizing Pytorch models.
114112
"""
113+
weights_quant_layer_cls = PytorchQuantizationWrapper
114+
activation_quant_layer_cls = PytorchActivationQuantizationHolder
115+
configurable_weights_quantizer_cls = ConfigurableWeightsQuantizer
116+
configurable_activation_quantizer_cls = ConfigurableActivationQuantizer
115117

116118
def __init__(self):
117119
super().__init__()
@@ -397,43 +399,6 @@ def get_substitutions_after_second_moment_correction(self, quant_config: Quantiz
397399
substitutions_list.append(pytorch_batchnorm_refusing())
398400
return substitutions_list
399401

400-
def get_sensitivity_evaluator(self,
401-
graph: Graph,
402-
quant_config: MixedPrecisionQuantizationConfig,
403-
representative_data_gen: Callable,
404-
fw_info: FrameworkInfo,
405-
disable_activation_for_metric: bool = False,
406-
hessian_info_service: HessianInfoService = None
407-
) -> SensitivityEvaluation:
408-
"""
409-
Creates and returns an object which handles the computation of a sensitivity metric for a mixed-precision
410-
configuration (comparing to the float model).
411-
412-
Args:
413-
graph: Graph to build its float and mixed-precision models.
414-
quant_config: QuantizationConfig of how the model should be quantized.
415-
representative_data_gen: Dataset to use for retrieving images for the models inputs.
416-
fw_info: FrameworkInfo object with information about the specific framework's model.
417-
disable_activation_for_metric: Whether to disable activation quantization when computing the MP metric.
418-
hessian_info_service: HessianScoresService to fetch approximations of the hessian scores for the float model.
419-
420-
Returns:
421-
A SensitivityEvaluation object.
422-
"""
423-
424-
return SensitivityEvaluation(graph=graph,
425-
quant_config=quant_config,
426-
representative_data_gen=representative_data_gen,
427-
fw_info=fw_info,
428-
fw_impl=self,
429-
set_layer_to_bitwidth=partial(set_layer_to_bitwidth,
430-
weights_quantizer_type=ConfigurableWeightsQuantizer,
431-
activation_quantizer_type=ConfigurableActivationQuantizer,
432-
weights_quant_layer_type=PytorchQuantizationWrapper,
433-
activation_quant_layer_type=PytorchActivationQuantizationHolder),
434-
disable_activation_for_metric=disable_activation_for_metric,
435-
hessian_info_service=hessian_info_service)
436-
437402
def get_node_prior_info(self,
438403
node: BaseNode,
439404
fw_info: FrameworkInfo,

0 commit comments

Comments
 (0)