Skip to content
18 changes: 14 additions & 4 deletions model_compression_toolkit/core/common/graph/base_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,14 +706,24 @@ def update_fused_nodes(self, fusion: List[Any]):
"""
self.fused_nodes.append(fusion)

def is_single_activation_cfg(self):
def has_any_configurable_activation(self) -> bool:
"""
Checks whether all nodes in the graph that have activation quantization are quantized with the same bit-width.
Checks whether any node in the graph has a configurable activation quantization.

Returns: True if all quantization config candidates of all nodes have the same activation quantization bit-width.
Returns:
Whether any node in the graph has a configurable activation quantization.
"""
return any([n.has_configurable_activation() for n in self.nodes])

def has_any_configurable_weights(self):
"""
Checks whether any node in the graph has any configurable weights quantization.

Returns:
Whether any node in the graph has any configurable weights quantization.
"""
return all([n.is_all_activation_candidates_equal() for n in self.nodes])

return any([n.has_any_configurable_weight() for n in self.nodes])

def replace_node(self, node_to_replace: BaseNode, new_node: BaseNode):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import List, Set, Dict, Optional, Tuple, Any
from typing import List, Set, Dict, Tuple

import numpy as np

from model_compression_toolkit.core import FrameworkInfo
from model_compression_toolkit.core.common import Graph, BaseNode
from model_compression_toolkit.core.common import Graph
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import \
RUTarget
Expand All @@ -36,42 +36,46 @@ def __init__(self, graph: Graph, fw_info: FrameworkInfo, fw_impl: FrameworkImple
self.fw_impl = fw_impl
self.ru_calculator = ResourceUtilizationCalculator(graph, fw_impl, fw_info)

def compute_utilization(self, ru_targets: Set[RUTarget], mp_cfg: List[int]) -> Dict[RUTarget, np.ndarray]:
    """
    Compute utilization of requested targets for a specific configuration:
    for weights and bops - total utilization,
    for activations and total - utilization per cut.

    Args:
        ru_targets: resource utilization targets to compute.
        mp_cfg: a list of candidates indices for configurable layers.

    Returns:
        Dict of the computed utilization per target, as 1d vector.
    """
    act_qcs, w_qcs = self.get_quantization_candidates(mp_cfg)

    ru, detailed_ru = self.ru_calculator.compute_resource_utilization(TargetInclusionCriterion.AnyQuantized,
                                                                      BitwidthMode.QCustom,
                                                                      act_qcs=act_qcs,
                                                                      w_qcs=w_qcs,
                                                                      ru_targets=ru_targets,
                                                                      allow_unused_qcs=True,
                                                                      return_detailed=True)

    # Wrap each aggregated value in a 1d vector so all targets share the same output shape.
    ru_dict = {k: np.array([v]) for k, v in ru.get_resource_utilization_dict(restricted_only=True).items()}
    # For activation and total we need utilization per cut, as different mp configurations might result in
    # different cuts to be maximal.
    for target in [RUTarget.ACTIVATION, RUTarget.TOTAL]:
        if target in ru_dict:
            ru_dict[target] = np.array(list(detailed_ru[target].values()))

    # Sanity: every target is a 1d vector, and per-cut targets agree on the number of cuts.
    assert all(v.ndim == 1 for v in ru_dict.values())
    if RUTarget.ACTIVATION in ru_targets and RUTarget.TOTAL in ru_targets:
        assert ru_dict[RUTarget.ACTIVATION].shape == ru_dict[RUTarget.TOTAL].shape

    assert len(ru_dict) == len(ru_targets), (f'Mismatch between the number of computed and requested metrics.'
                                             f'Requested {ru_targets}')
    return ru_dict

def get_quantization_candidates(self, mp_cfg) \
-> Tuple[Dict[BaseNode, NodeActivationQuantizationConfig], Dict[BaseNode, NodeWeightsQuantizationConfig]]:
-> Tuple[Dict[str, NodeActivationQuantizationConfig], Dict[str, NodeWeightsQuantizationConfig]]:
"""
Retrieve quantization candidates objects for weights and activations from the configuration list.

Expand All @@ -87,71 +91,3 @@ def get_quantization_candidates(self, mp_cfg) \
act_qcs = {n.name: cfg.activation_quantization_cfg for n, cfg in node_qcs.items()}
w_qcs = {n.name: cfg.weights_quantization_cfg for n, cfg in node_qcs.items()}
return act_qcs, w_qcs

def _weights_utilization(self, w_qcs: Optional[Dict[BaseNode, NodeWeightsQuantizationConfig]]) -> Dict[BaseNode, float]:
    """
    Compute per-node weights utilization in bytes.

    When a configuration is passed, configurable weights are measured with their custom
    bit-widths; otherwise non-configurable weights are measured with the default
    single-precision bit-width.

    Args:
        w_qcs: nodes quantization configuration to compute, or None.

    Returns:
        Weight utilization per node.
    """
    if w_qcs:
        criterion = TargetInclusionCriterion.QConfigurable
        mode = BitwidthMode.QCustom
    else:
        criterion = TargetInclusionCriterion.QNonConfigurable
        mode = BitwidthMode.QDefaultSP

    _, per_node_util, _ = self.ru_calculator.compute_weights_utilization(target_criterion=criterion,
                                                                         bitwidth_mode=mode,
                                                                         w_qcs=w_qcs)
    return {node: util.bytes for node, util in per_node_util.items()}

def _activation_utilization(self, act_qcs: Optional[Dict[BaseNode, NodeActivationQuantizationConfig]]) \
        -> Optional[Dict[Any, float]]:
    """
    Compute per-cut activation utilization in bytes using MaxCut, for all quantized nodes,
    when a configuration is passed.

    Args:
        act_qcs: nodes activation configuration or None.

    Returns:
        Activation utilization per cut, or empty dict if no configuration was passed.
    """
    # Maxcut activation utilization is computed for all quantized nodes, so non-configurable memory is already
    # covered by the computation of configurable activations.
    if not act_qcs:
        return {}

    _, per_cut_util, *_ = self.ru_calculator.compute_activation_utilization_by_cut(
        TargetInclusionCriterion.AnyQuantized, bitwidth_mode=BitwidthMode.QCustom, act_qcs=act_qcs)
    return {cut: util.bytes for cut, util in per_cut_util.items()}

def _bops_utilization(self,
                      act_qcs: Optional[Dict[BaseNode, NodeActivationQuantizationConfig]],
                      w_qcs: Optional[Dict[BaseNode, NodeWeightsQuantizationConfig]]) -> np.ndarray:
    """
    Build the resource utilization vector of bit-operations (BOPS) counts for the given
    mixed-precision configuration.

    Args:
        act_qcs: nodes activation configuration or None.
        w_qcs: nodes quantization configuration to compute, or None.
          Either both are provided, or both are None.

    Returns:
        A vector of node's BOPS count.
    """
    # Both configs must be given together, or both omitted.
    assert (act_qcs is None) == (w_qcs is None), 'act_qcs and w_qcs should both be provided or both be None.'
    if act_qcs is None:
        return np.array([])

    _, per_node_bops = self.ru_calculator.compute_bops(TargetInclusionCriterion.Any, BitwidthMode.QCustom,
                                                       act_qcs=act_qcs, w_qcs=w_qcs)
    return np.array(list(per_node_bops.values()))
Original file line number Diff line number Diff line change
Expand Up @@ -13,37 +13,27 @@
# limitations under the License.
# ==============================================================================

import copy
from enum import Enum
import numpy as np
from typing import List, Callable, Dict
from typing import List, Callable

from model_compression_toolkit.core import MixedPrecisionQuantizationConfig
from model_compression_toolkit.core.common import Graph
from model_compression_toolkit.core.common.hessian import HessianInfoService
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization, RUTarget
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_search_manager import MixedPrecisionSearchManager
from model_compression_toolkit.core.common.mixed_precision.search_methods.linear_programming import \
mp_integer_programming_search
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
from model_compression_toolkit.core.common.hessian import HessianInfoService
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_search_manager import \
MixedPrecisionSearchManager
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import \
ResourceUtilization
from model_compression_toolkit.core.common.mixed_precision.solution_refinement_procedure import \
greedy_solution_refinement_procedure
from model_compression_toolkit.core.common.substitutions.apply_substitutions import substitute
from model_compression_toolkit.logger import Logger


class BitWidthSearchMethod(Enum):
# When adding a new search_methods MP configuration method, these enum and factory dictionary
# should be updated with it's kind and a search_method implementation.
INTEGER_PROGRAMMING = 0


search_methods = {
BitWidthSearchMethod.INTEGER_PROGRAMMING: mp_integer_programming_search}


def search_bit_width(graph_to_search_cfg: Graph,
def search_bit_width(graph: Graph,
fw_info: FrameworkInfo,
fw_impl: FrameworkImplementation,
target_resource_utilization: ResourceUtilization,
Expand All @@ -60,7 +50,7 @@ def search_bit_width(graph_to_search_cfg: Graph,
target_resource_utilization have to be passed. If it was not passed, the facade is not supposed to get here by now.

Args:
graph_to_search_cfg: Graph to search a MP configuration for.
graph: Graph to search a MP configuration for.
fw_info: FrameworkInfo object about the specific framework (e.g., attributes of different layers' weights to quantize).
fw_impl: FrameworkImplementation object with specific framework methods implementation.
target_resource_utilization: Target Resource Utilization to bound our feasible solution space s.t the configuration does not violate it.
Expand All @@ -75,51 +65,36 @@ def search_bit_width(graph_to_search_cfg: Graph,
bit-width index on the node).

"""

# target_resource_utilization have to be passed. If it was not passed, the facade is not supposed to get here by now.
if target_resource_utilization is None:
Logger.critical("Target ResourceUtilization is required for the bit-width search method's configuration.") # pragma: no cover

# Set graph for MP search
graph = copy.deepcopy(graph_to_search_cfg) # Copy graph before searching
if target_resource_utilization.bops_restricted():
# TODO: we only need the virtual graph is both activations and weights are configurable
# Since Bit-operations count target resource utilization is set, we need to reconstruct the graph for the MP search
graph = substitute(graph, fw_impl.get_substitutions_virtual_weights_activation_coupling())
assert target_resource_utilization.is_any_restricted()

# If we only run weights compression with MP than no need to consider activation quantization when computing the
# MP metric (it adds noise to the computation)
tru = target_resource_utilization
weight_only_restricted = tru.weight_restricted() and not (tru.activation_restricted() or
tru.total_mem_restricted() or
tru.bops_restricted())
disable_activation_for_metric = weight_only_restricted or graph_to_search_cfg.is_single_activation_cfg()
disable_activation_for_metric = weight_only_restricted or not graph.has_any_configurable_activation()

# Set Sensitivity Evaluator for MP search. It should always work with the original MP graph,
# even if a virtual graph was created (and is used only for BOPS utilization computation purposes)
se = fw_impl.get_sensitivity_evaluator(
graph_to_search_cfg,
graph,
mp_config,
representative_data_gen=representative_data_gen,
fw_info=fw_info,
disable_activation_for_metric=disable_activation_for_metric,
hessian_info_service=hessian_info_service)

# Instantiate a manager object
if search_method != BitWidthSearchMethod.INTEGER_PROGRAMMING:
raise NotImplementedError()

# Search manager and LP are highly coupled, so LP search method was moved inside search manager.
search_manager = MixedPrecisionSearchManager(graph,
fw_info,
fw_impl,
se,
target_resource_utilization,
original_graph=graph_to_search_cfg)

if search_method not in search_methods:
raise NotImplementedError() # pragma: no cover

search_method_fn = search_methods[search_method]
# Search for the desired mixed-precision configuration
result_bit_cfg = search_method_fn(search_manager,
target_resource_utilization)
target_resource_utilization)
result_bit_cfg = search_manager.search()

if mp_config.refine_mp_solution:
result_bit_cfg = greedy_solution_refinement_procedure(result_bit_cfg, search_manager, target_resource_utilization)
Expand Down
Loading