Skip to content

Commit 48df9c7

Browse files
irenab
authored and committed
fixes for api change
1 parent cd6518a commit 48df9c7

File tree

10 files changed

+44
-334
lines changed

10 files changed

+44
-334
lines changed

model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
MixedPrecisionSearchManager
2626
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import \
2727
ResourceUtilization
28+
from model_compression_toolkit.core.common.mixed_precision.sensitivity_evaluation import SensitivityEvaluation
2829
from model_compression_toolkit.core.common.mixed_precision.solution_refinement_procedure import \
2930
greedy_solution_refinement_procedure
3031

@@ -78,11 +79,12 @@ def search_bit_width(graph: Graph,
7879

7980
# Set Sensitivity Evaluator for MP search. It should always work with the original MP graph,
8081
# even if a virtual graph was created (and is used only for BOPS utilization computation purposes)
81-
se = fw_impl.get_sensitivity_evaluator(
82+
se = SensitivityEvaluation(
8283
graph,
8384
mp_config,
8485
representative_data_gen=representative_data_gen,
8586
fw_info=fw_info,
87+
fw_impl=fw_impl,
8688
disable_activation_for_metric=disable_activation_for_metric,
8789
hessian_info_service=hessian_info_service)
8890

model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
from model_compression_toolkit.constants import AXIS
2121
from model_compression_toolkit.core import FrameworkInfo, MixedPrecisionQuantizationConfig
2222
from model_compression_toolkit.core.common import Graph, BaseNode
23+
from model_compression_toolkit.core.common.mixed_precision.set_layer_to_bitwidth import \
24+
set_activation_quant_layer_to_bitwidth, set_weights_quant_layer_to_bitwidth
2325
from model_compression_toolkit.core.common.quantization.node_quantization_config import ActivationQuantizationMode
2426
from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
2527
from model_compression_toolkit.core.common.similarity_analyzer import compute_kl_divergence
@@ -41,7 +43,6 @@ def __init__(self,
4143
representative_data_gen: Callable,
4244
fw_info: FrameworkInfo,
4345
fw_impl: Any,
44-
set_layer_to_bitwidth: Callable,
4546
disable_activation_for_metric: bool = False,
4647
hessian_info_service: HessianInfoService = None
4748
):
@@ -63,8 +64,6 @@ def __init__(self,
6364
quant_config: MP Quantization configuration for how the graph should be quantized.
6465
representative_data_gen: Dataset used for getting batches for inference.
6566
fw_impl: FrameworkImplementation object with a specific framework methods implementation.
66-
set_layer_to_bitwidth: A fw-dependent function that allows to configure a configurable MP model
67-
with a specific bit-width configuration.
6867
disable_activation_for_metric: Whether to disable activation quantization when computing the MP metric.
6968
hessian_info_service: HessianInfoService to fetch Hessian approximation information.
7069
@@ -74,7 +73,6 @@ def __init__(self,
7473
self.representative_data_gen = representative_data_gen
7574
self.fw_info = fw_info
7675
self.fw_impl = fw_impl
77-
self.set_layer_to_bitwidth = set_layer_to_bitwidth
7876
self.disable_activation_for_metric = disable_activation_for_metric
7977
if self.quant_config.use_hessian_based_scores:
8078
if not isinstance(hessian_info_service, HessianInfoService):
@@ -307,7 +305,13 @@ def _configure_node_bitwidth(self,
307305
f"Matching layers for node {node_name} not found in the mixed precision model configuration.") # pragma: no cover
308306

309307
for current_layer in layers_to_config:
310-
self.set_layer_to_bitwidth(current_layer, mp_model_configuration[node_idx_to_configure])
308+
if isinstance(current_layer, self.fw_impl.activation_quant_layer_cls):
309+
set_activation_quant_layer_to_bitwidth(current_layer, mp_model_configuration[node_idx_to_configure],
310+
self.fw_impl)
311+
else:
312+
assert isinstance(current_layer, self.fw_impl.weights_quant_layer_cls)
313+
set_weights_quant_layer_to_bitwidth(current_layer, mp_model_configuration[node_idx_to_configure],
314+
self.fw_impl)
311315

312316
def _compute_points_distance(self,
313317
baseline_tensors: List[Any],

model_compression_toolkit/core/common/mixed_precision/set_layer_to_bitwidth.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,9 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
# ==============================================================================
15-
import typing
16-
from typing import Any, Optional
15+
from typing import Any, Optional, TYPE_CHECKING
1716

18-
if typing.TYPE_CHECKING:
17+
if TYPE_CHECKING:
1918
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
2019

2120

@@ -32,7 +31,8 @@ def set_activation_quant_layer_to_bitwidth(quantization_layer: Any,
3231
fw_impl: framework implementation object.
3332
"""
3433
assert isinstance(quantization_layer, fw_impl.activation_quant_layer_cls)
35-
assert isinstance(quantization_layer.activation_holder_quantizer, fw_impl.configurable_activation_quantizer_cls)
34+
# TODO irena enable after float mp
35+
# assert isinstance(quantization_layer.activation_holder_quantizer, fw_impl.configurable_activation_quantizer_cls)
3636
quantization_layer.activation_holder_quantizer.set_active_activation_quantizer(bitwidth_idx)
3737

3838

@@ -51,6 +51,7 @@ def set_weights_quant_layer_to_bitwidth(quantization_layer: Any,
5151
assert isinstance(quantization_layer, fw_impl.weights_quant_layer_cls)
5252
configurable_quantizers = [q for q in quantization_layer.weights_quantizers.values()
5353
if isinstance(q, fw_impl.configurable_weights_quantizer_cls)]
54-
assert configurable_quantizers
54+
# TODO irena enable after float mp
55+
# assert configurable_quantizers
5556
for quantizer in configurable_quantizers:
5657
quantizer.set_weights_bit_width_index(bitwidth_idx)

tests/keras_tests/function_tests/test_sensitivity_eval_non_suppoerted_output.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919

2020
from model_compression_toolkit.core import MixedPrecisionQuantizationConfig
21+
from model_compression_toolkit.core.common.mixed_precision.sensitivity_evaluation import SensitivityEvaluation
2122
from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
2223
from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
2324
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2keras import \
@@ -93,12 +94,13 @@ def verify_test_for_model(self, model):
9394

9495
# Reducing the default number of samples for Mixed precision Hessian approximation
9596
# to allow quick execution of the test
96-
se = keras_impl.get_sensitivity_evaluator(graph,
97-
MixedPrecisionQuantizationConfig(use_hessian_based_scores=True,
98-
num_of_images=2),
99-
representative_dataset,
100-
DEFAULT_KERAS_INFO,
101-
hessian_info_service=hessian_info_service)
97+
se = SensitivityEvaluation(graph,
98+
MixedPrecisionQuantizationConfig(use_hessian_based_scores=True,
99+
num_of_images=2),
100+
representative_dataset,
101+
DEFAULT_KERAS_INFO,
102+
keras_impl,
103+
hessian_info_service=hessian_info_service)
102104

103105
def test_not_supported_output_argmax(self):
104106
model = argmax_output_model((8, 8, 3))

tests/keras_tests/function_tests/test_set_layer_to_bitwidth.py

Lines changed: 0 additions & 143 deletions
This file was deleted.

tests/keras_tests/non_parallel_tests/test_lp_search_bitwidth.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
MixedPrecisionQuantizationConfig
2626
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_search_facade import search_bit_width, \
2727
BitWidthSearchMethod
28+
from model_compression_toolkit.core.common.mixed_precision.sensitivity_evaluation import SensitivityEvaluation
2829
from model_compression_toolkit.core.common.model_collector import ModelCollector
2930
from model_compression_toolkit.core.common.quantization.core_config import CoreConfig
3031
from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_computation import \
@@ -96,10 +97,11 @@ def representative_data_gen():
9697

9798
calculate_quantization_params(graph, fw_impl=keras_impl, repr_data_gen_fn=representative_data_gen)
9899

99-
keras_impl.get_sensitivity_evaluator(graph,
100-
core_config.mixed_precision_config,
101-
representative_data_gen,
102-
fw_info=fw_info)
100+
SensitivityEvaluation(graph,
101+
core_config.mixed_precision_config,
102+
representative_data_gen,
103+
fw_info=fw_info,
104+
fw_impl=keras_impl)
103105

104106
cfg = search_bit_width(graph=graph,
105107
fw_info=DEFAULT_KERAS_INFO,

0 commit comments

Comments (0)