Skip to content

Commit 77e3d11

Browse files
authored
Add an option to dynamically disable configurable quantizers, separate configuration api for activation and weights quantizers (#1438)
* add an option to disable configurable quantizers * separate configuring activation and weight quantizers
1 parent 5bfd07d commit 77e3d11

File tree

20 files changed

+307
-513
lines changed

20 files changed

+307
-513
lines changed

model_compression_toolkit/core/common/framework_implementation.py

Lines changed: 6 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -13,31 +13,31 @@
1313
# limitations under the License.
1414
# ==============================================================================
1515
from abc import ABC, abstractmethod
16-
from typing import Callable, Any, List, Tuple, Dict, Generator
16+
from typing import Callable, Any, List, Tuple, Generator, Type
1717

1818
import numpy as np
1919

2020
from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS
21-
from model_compression_toolkit.core import MixedPrecisionQuantizationConfig
2221
from model_compression_toolkit.core import common
2322
from model_compression_toolkit.core.common import BaseNode
24-
from model_compression_toolkit.core.common.collectors.statistics_collector import BaseStatsCollector
2523
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
2624
from model_compression_toolkit.core.common.graph.base_graph import Graph
27-
from model_compression_toolkit.core.common.hessian import HessianScoresRequest, HessianInfoService
28-
from model_compression_toolkit.core.common.mixed_precision.sensitivity_evaluation import SensitivityEvaluation
25+
from model_compression_toolkit.core.common.hessian import HessianScoresRequest
2926
from model_compression_toolkit.core.common.model_builder_mode import ModelBuilderMode
3027
from model_compression_toolkit.core.common.node_prior_info import NodePriorInfo
3128
from model_compression_toolkit.core.common.quantization.core_config import CoreConfig
3229
from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
33-
from model_compression_toolkit.core.common.user_info import UserInformation
3430

3531

3632
class FrameworkImplementation(ABC):
3733
"""
3834
An abstract class with abstract methods that should be implemented when supporting a new
3935
framework in MCT.
4036
"""
37+
weights_quant_layer_cls: Type
38+
activation_quant_layer_cls: Type
39+
configurable_weights_quantizer_cls: Type
40+
configurable_activation_quantizer_cls: Type
4141

4242
@property
4343
def constants(self):
@@ -327,33 +327,6 @@ def get_substitutions_after_second_moment_correction(self, quant_config: Quantiz
327327
f'framework\'s get_substitutions_after_second_moment_correction '
328328
f'method.') # pragma: no cover
329329

330-
@abstractmethod
331-
def get_sensitivity_evaluator(self,
332-
graph: Graph,
333-
quant_config: MixedPrecisionQuantizationConfig,
334-
representative_data_gen: Callable,
335-
fw_info: FrameworkInfo,
336-
hessian_info_service: HessianInfoService = None,
337-
disable_activation_for_metric: bool = False) -> SensitivityEvaluation:
338-
"""
339-
Creates and returns an object which handles the computation of a sensitivity metric for a mixed-precision
340-
configuration (comparing to the float model).
341-
342-
Args:
343-
graph: Graph to build its float and mixed-precision models.
344-
quant_config: QuantizationConfig of how the model should be quantized.
345-
representative_data_gen: Dataset to use for retrieving images for the models inputs.
346-
fw_info: FrameworkInfo object with information about the specific framework's model.
347-
disable_activation_for_metric: Whether to disable activation quantization when computing the MP metric.
348-
hessian_info_service: HessianInfoService to fetch information based on Hessian-approximation.
349-
350-
Returns:
351-
A function that computes the metric.
352-
"""
353-
354-
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
355-
f'framework\'s get_sensitivity_evaluator method.') # pragma: no cover
356-
357330
def get_node_prior_info(self, node: BaseNode,
358331
fw_info: FrameworkInfo,
359332
graph: Graph) -> NodePriorInfo:

model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
MixedPrecisionSearchManager
2626
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import \
2727
ResourceUtilization
28+
from model_compression_toolkit.core.common.mixed_precision.sensitivity_evaluation import SensitivityEvaluation
2829
from model_compression_toolkit.core.common.mixed_precision.solution_refinement_procedure import \
2930
greedy_solution_refinement_procedure
3031

@@ -78,11 +79,12 @@ def search_bit_width(graph: Graph,
7879

7980
# Set Sensitivity Evaluator for MP search. It should always work with the original MP graph,
8081
# even if a virtual graph was created (and is used only for BOPS utilization computation purposes)
81-
se = fw_impl.get_sensitivity_evaluator(
82+
se = SensitivityEvaluation(
8283
graph,
8384
mp_config,
8485
representative_data_gen=representative_data_gen,
8586
fw_info=fw_info,
87+
fw_impl=fw_impl,
8688
disable_activation_for_metric=disable_activation_for_metric,
8789
hessian_info_service=hessian_info_service)
8890

model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
from model_compression_toolkit.constants import AXIS
2121
from model_compression_toolkit.core import FrameworkInfo, MixedPrecisionQuantizationConfig
2222
from model_compression_toolkit.core.common import Graph, BaseNode
23+
from model_compression_toolkit.core.common.mixed_precision.set_layer_to_bitwidth import \
24+
set_activation_quant_layer_to_bitwidth, set_weights_quant_layer_to_bitwidth
2325
from model_compression_toolkit.core.common.quantization.node_quantization_config import ActivationQuantizationMode
2426
from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
2527
from model_compression_toolkit.core.common.similarity_analyzer import compute_kl_divergence
@@ -41,7 +43,6 @@ def __init__(self,
4143
representative_data_gen: Callable,
4244
fw_info: FrameworkInfo,
4345
fw_impl: Any,
44-
set_layer_to_bitwidth: Callable,
4546
disable_activation_for_metric: bool = False,
4647
hessian_info_service: HessianInfoService = None
4748
):
@@ -63,8 +64,6 @@ def __init__(self,
6364
quant_config: MP Quantization configuration for how the graph should be quantized.
6465
representative_data_gen: Dataset used for getting batches for inference.
6566
fw_impl: FrameworkImplementation object with a specific framework methods implementation.
66-
set_layer_to_bitwidth: A fw-dependent function that allows to configure a configurable MP model
67-
with a specific bit-width configuration.
6867
disable_activation_for_metric: Whether to disable activation quantization when computing the MP metric.
6968
hessian_info_service: HessianInfoService to fetch Hessian approximation information.
7069
@@ -74,7 +73,6 @@ def __init__(self,
7473
self.representative_data_gen = representative_data_gen
7574
self.fw_info = fw_info
7675
self.fw_impl = fw_impl
77-
self.set_layer_to_bitwidth = set_layer_to_bitwidth
7876
self.disable_activation_for_metric = disable_activation_for_metric
7977
if self.quant_config.use_hessian_based_scores:
8078
if not isinstance(hessian_info_service, HessianInfoService):
@@ -307,7 +305,13 @@ def _configure_node_bitwidth(self,
307305
f"Matching layers for node {node_name} not found in the mixed precision model configuration.") # pragma: no cover
308306

309307
for current_layer in layers_to_config:
310-
self.set_layer_to_bitwidth(current_layer, mp_model_configuration[node_idx_to_configure])
308+
if isinstance(current_layer, self.fw_impl.activation_quant_layer_cls):
309+
set_activation_quant_layer_to_bitwidth(current_layer, mp_model_configuration[node_idx_to_configure],
310+
self.fw_impl)
311+
else:
312+
assert isinstance(current_layer, self.fw_impl.weights_quant_layer_cls)
313+
set_weights_quant_layer_to_bitwidth(current_layer, mp_model_configuration[node_idx_to_configure],
314+
self.fw_impl)
311315

312316
def _compute_points_distance(self,
313317
baseline_tensors: List[Any],

model_compression_toolkit/core/common/mixed_precision/set_layer_to_bitwidth.py

Lines changed: 33 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -12,39 +12,46 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
# ==============================================================================
15-
from typing import Any
15+
from typing import Any, Optional, TYPE_CHECKING
1616

17+
if TYPE_CHECKING:
18+
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
1719

18-
def set_layer_to_bitwidth(quantization_layer: Any,
19-
bitwidth_idx: int,
20-
weights_quantizer_type: type,
21-
activation_quantizer_type: type,
22-
weights_quant_layer_type: type,
23-
activation_quant_layer_type: type):
20+
21+
def set_activation_quant_layer_to_bitwidth(quantization_layer: Any,
22+
bitwidth_idx: Optional[int],
23+
fw_impl: 'FrameworkImplementation'):
2424
"""
25-
Configures a layer's configurable quantizer to work with a different bit-width.
25+
Configures a layer's configurable activation quantizer to work with a different bit-width.
2626
The bitwidth_idx is the index of the actual quantizer that the quantizer object in the quantization_layer wraps/holds.
2727
2828
Args:
2929
quantization_layer: Layer to change its bit-width.
30-
bitwidth_idx: Index of the bit-width the layer should work with.
31-
weights_quantizer_type: A class of weights quantizer with configurable bitwidth options.
32-
activation_quantizer_type: A class of activation quantizer with configurable bitwidth options.
33-
weights_quant_layer_type: A class of a weights layer wrapper.
34-
activation_quant_layer_type: A class of an activation quantization holder.
30+
bitwidth_idx: Index of the bit-width the layer should work with, or None to disable quantization.
31+
fw_impl: framework implementation object.
3532
"""
33+
assert isinstance(quantization_layer, fw_impl.activation_quant_layer_cls)
34+
# TODO irena enable after float mp
35+
# assert isinstance(quantization_layer.activation_holder_quantizer, fw_impl.configurable_activation_quantizer_cls)
36+
quantization_layer.activation_holder_quantizer.set_active_activation_quantizer(bitwidth_idx)
37+
3638

37-
if isinstance(quantization_layer, weights_quant_layer_type):
38-
for _, quantizer in quantization_layer.weights_quantizers.items():
39-
if isinstance(quantizer, weights_quantizer_type):
40-
# Setting bitwidth only for configurable layers. There might be wrapped layers that aren't configurable,
41-
# for instance, if only activations are quantized with mixed precision and weights are quantized with
42-
# fixed precision
43-
quantizer.set_weights_bit_width_index(bitwidth_idx)
39+
def set_weights_quant_layer_to_bitwidth(quantization_layer: Any,
40+
bitwidth_idx: Optional[int],
41+
fw_impl: 'FrameworkImplementation'):
42+
"""
43+
Configures a layer's configurable weights quantizer to work with a different bit-width.
44+
The bitwidth_idx is the index of the actual quantizer that the quantizer object in the quantization_layer wraps/holds.
4445
45-
if isinstance(quantization_layer, activation_quant_layer_type):
46-
if isinstance(quantization_layer.activation_holder_quantizer, activation_quantizer_type):
47-
# Setting bitwidth only for configurable layers. There might be activation layers that isn't configurable,
48-
# for instance, if only weights are quantized with mixed precision and activation are quantized with
49-
# fixed precision
50-
quantization_layer.activation_holder_quantizer.set_active_activation_quantizer(bitwidth_idx)
46+
Args:
47+
quantization_layer: Layer to change its bit-width.
48+
bitwidth_idx: Index of the bit-width the layer should work with, or None to disable quantization.
49+
fw_impl: framework implementation object.
50+
"""
51+
assert isinstance(quantization_layer, fw_impl.weights_quant_layer_cls)
52+
configurable_quantizers = [q for q in quantization_layer.weights_quantizers.values()
53+
if isinstance(q, fw_impl.configurable_weights_quantizer_cls)]
54+
# TODO irena enable after float mp
55+
# assert configurable_quantizers
56+
for quantizer in configurable_quantizers:
57+
quantizer.set_weights_bit_width_index(bitwidth_idx)

model_compression_toolkit/core/keras/keras_implementation.py

Lines changed: 8 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414
# ==============================================================================
1515
from functools import partial
16-
from typing import List, Any, Tuple, Callable, Dict, Union, Generator
16+
from typing import List, Any, Tuple, Callable, Union, Generator
1717

1818
import numpy as np
1919
import tensorflow as tf
@@ -22,7 +22,7 @@
2222

2323
from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS
2424
from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
25-
from model_compression_toolkit.core.common.hessian import HessianScoresRequest, HessianMode, HessianInfoService
25+
from model_compression_toolkit.core.common.hessian import HessianScoresRequest, HessianMode
2626
from model_compression_toolkit.core.keras.data_util import data_gen_to_dataloader
2727
from model_compression_toolkit.core.keras.graph_substitutions.substitutions.remove_identity import RemoveIdentity
2828
from model_compression_toolkit.core.keras.hessian.activation_hessian_scores_calculator_keras import \
@@ -35,8 +35,6 @@
3535
from model_compression_toolkit.exporter.model_wrapper.keras.builder.node_to_quantizer import \
3636
get_weights_quantizer_for_node, get_activations_quantizer_for_node
3737
from model_compression_toolkit.logger import Logger
38-
from model_compression_toolkit.core.common.mixed_precision.sensitivity_evaluation import SensitivityEvaluation
39-
from model_compression_toolkit.core.common.mixed_precision.set_layer_to_bitwidth import set_layer_to_bitwidth
4038
from model_compression_toolkit.core.common.similarity_analyzer import compute_kl_divergence, compute_cs, compute_mse
4139
from model_compression_toolkit.core.keras.constants import ACTIVATION, SOFTMAX, SIGMOID, ARGMAX, LAYER_NAME, \
4240
COMBINED_NMS
@@ -61,7 +59,7 @@
6159
from keras.layers import Dense, Activation, Conv2D, DepthwiseConv2D, Conv2DTranspose, Concatenate, Add # pragma: no cover
6260
from keras.layers.core import TFOpLambda # pragma: no cover
6361

64-
from model_compression_toolkit.core import QuantizationConfig, FrameworkInfo, CoreConfig, MixedPrecisionQuantizationConfig
62+
from model_compression_toolkit.core import QuantizationConfig, FrameworkInfo, CoreConfig
6563
from model_compression_toolkit.core import common
6664
from model_compression_toolkit.core.common import Graph, BaseNode
6765
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
@@ -95,7 +93,7 @@
9593
from model_compression_toolkit.core.keras.graph_substitutions.substitutions.scale_equalization import \
9694
ScaleEqualization, ScaleEqualizationWithPad, ScaleEqualizationMidActivation, ScaleEqualizationMidActivationWithPad
9795
from model_compression_toolkit.core.keras.graph_substitutions.substitutions.separableconv_decomposition import \
98-
SeparableConvDecomposition, DEPTH_MULTIPLIER
96+
SeparableConvDecomposition
9997
from model_compression_toolkit.core.keras.graph_substitutions.substitutions.shift_negative_activation import \
10098
keras_apply_shift_negative_correction
10199
from model_compression_toolkit.core.keras.graph_substitutions.substitutions.dwconv_to_conv import DwconvToConv
@@ -110,9 +108,10 @@ class KerasImplementation(FrameworkImplementation):
110108
"""
111109
A class with implemented methods to support optimizing Keras models.
112110
"""
113-
114-
def __init__(self):
115-
super().__init__()
111+
weights_quant_layer_cls = KerasQuantizationWrapper
112+
activation_quant_layer_cls = KerasActivationQuantizationHolder
113+
configurable_weights_quantizer_cls = ConfigurableWeightsQuantizer
114+
configurable_activation_quantizer_cls = ConfigurableActivationQuantizer
116115

117116
@property
118117
def constants(self):
@@ -401,42 +400,6 @@ def get_substitutions_after_second_moment_correction(self, quant_config: Quantiz
401400
substitutions_list.append(keras_batchnorm_refusing())
402401
return substitutions_list
403402

404-
def get_sensitivity_evaluator(self,
405-
graph: Graph,
406-
quant_config: MixedPrecisionQuantizationConfig,
407-
representative_data_gen: Callable,
408-
fw_info: FrameworkInfo,
409-
disable_activation_for_metric: bool = False,
410-
hessian_info_service: HessianInfoService = None) -> SensitivityEvaluation:
411-
"""
412-
Creates and returns an object which handles the computation of a sensitivity metric for a mixed-precision
413-
configuration (comparing to the float model).
414-
415-
Args:
416-
graph: Graph to build its float and mixed-precision models.
417-
quant_config: QuantizationConfig of how the model should be quantized.
418-
representative_data_gen: Dataset to use for retrieving images for the models inputs.
419-
fw_info: FrameworkInfo object with information about the specific framework's model.
420-
disable_activation_for_metric: Whether to disable activation quantization when computing the MP metric.
421-
hessian_info_service: HessianScoresService to fetch scores based on a Hessian-approximation for the float model.
422-
423-
Returns:
424-
A SensitivityEvaluation object.
425-
"""
426-
427-
return SensitivityEvaluation(graph=graph,
428-
quant_config=quant_config,
429-
representative_data_gen=representative_data_gen,
430-
fw_info=fw_info,
431-
fw_impl=self,
432-
set_layer_to_bitwidth=partial(set_layer_to_bitwidth,
433-
weights_quantizer_type=ConfigurableWeightsQuantizer,
434-
activation_quantizer_type=ConfigurableActivationQuantizer,
435-
weights_quant_layer_type=KerasQuantizationWrapper,
436-
activation_quant_layer_type=KerasActivationQuantizationHolder),
437-
disable_activation_for_metric=disable_activation_for_metric,
438-
hessian_info_service=hessian_info_service)
439-
440403
def get_node_prior_info(self,
441404
node: BaseNode,
442405
fw_info: FrameworkInfo,

0 commit comments

Comments
 (0)