Skip to content

Commit 20c321c

Browse files
authored
MP cleanup and partial refactoring (#1391)
* move virtual graph creation and search method call inside MP manager * move sensitivity computation from linear_programming to MP search manager * remove separate computation for non-configurable nodes * simplify ru constraints construction in mixed precision * convert LP functions into class, remove dependency on MPSearchManager, call LP from MPSearchManager with precomputed metrics
1 parent 4417cbb commit 20c321c

File tree

16 files changed

+471
-643
lines changed

16 files changed

+471
-643
lines changed

model_compression_toolkit/core/common/graph/base_graph.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -706,14 +706,24 @@ def update_fused_nodes(self, fusion: List[Any]):
706706
"""
707707
self.fused_nodes.append(fusion)
708708

709-
def is_single_activation_cfg(self):
709+
def has_any_configurable_activation(self) -> bool:
710710
"""
711-
Checks whether all nodes in the graph that have activation quantization are quantized with the same bit-width.
711+
Checks whether any node in the graph has a configurable activation quantization.
712712
713-
Returns: True if all quantization config candidates of all nodes have the same activation quantization bit-width.
713+
Returns:
714+
Whether any node in the graph has a configurable activation quantization.
715+
"""
716+
return any([n.has_configurable_activation() for n in self.nodes])
717+
718+
def has_any_configurable_weights(self):
719+
"""
720+
Checks whether any node in the graph has any configurable weights quantization.
714721
722+
Returns:
723+
Whether any node in the graph has any configurable weights quantization.
715724
"""
716-
return all([n.is_all_activation_candidates_equal() for n in self.nodes])
725+
726+
return any([n.has_any_configurable_weight() for n in self.nodes])
717727

718728
def replace_node(self, node_to_replace: BaseNode, new_node: BaseNode):
719729
"""

model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py

Lines changed: 32 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,12 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
# ==============================================================================
15-
from typing import List, Set, Dict, Optional, Tuple, Any
15+
from typing import List, Set, Dict, Tuple
1616

1717
import numpy as np
1818

1919
from model_compression_toolkit.core import FrameworkInfo
20-
from model_compression_toolkit.core.common import Graph, BaseNode
20+
from model_compression_toolkit.core.common import Graph
2121
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
2222
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import \
2323
RUTarget
@@ -36,42 +36,46 @@ def __init__(self, graph: Graph, fw_info: FrameworkInfo, fw_impl: FrameworkImple
3636
self.fw_impl = fw_impl
3737
self.ru_calculator = ResourceUtilizationCalculator(graph, fw_impl, fw_info)
3838

39-
def compute_utilization(self, ru_targets: Set[RUTarget], mp_cfg: Optional[List[int]]) -> Dict[RUTarget, np.ndarray]:
39+
def compute_utilization(self, ru_targets: Set[RUTarget], mp_cfg: List[int]) -> Dict[RUTarget, np.ndarray]:
4040
"""
41-
Compute utilization of requested targets for a specific configuration in the format expected by LP problem
42-
formulation namely a vector of ru values for relevant memory elements (nodes or cuts) in a constant order
43-
(between calls).
41+
Compute utilization of requested targets for a specific configuration:
42+
for weights and bops - total utilization,
43+
for activations and total - utilization per cut.
4444
4545
Args:
4646
ru_targets: resource utilization targets to compute.
4747
mp_cfg: a list of candidates indices for configurable layers.
4848
4949
Returns:
50-
Dict of the computed utilization per target.
50+
Dict of the computed utilization per target, as 1d vector.
5151
"""
52-
53-
ru = {}
54-
act_qcs, w_qcs = self.get_quantization_candidates(mp_cfg) if mp_cfg else (None, None)
55-
if RUTarget.WEIGHTS in ru_targets:
56-
wu = self._weights_utilization(w_qcs)
57-
ru[RUTarget.WEIGHTS] = np.array(list(wu.values()))
58-
59-
if RUTarget.ACTIVATION in ru_targets:
60-
au = self._activation_utilization(act_qcs)
61-
ru[RUTarget.ACTIVATION] = np.array(list(au.values()))
62-
63-
if RUTarget.BOPS in ru_targets:
64-
ru[RUTarget.BOPS] = self._bops_utilization(act_qcs=act_qcs, w_qcs=w_qcs)
65-
66-
if RUTarget.TOTAL in ru_targets:
67-
raise ValueError('Total target should be computed based on weights and activations targets.')
68-
69-
assert len(ru) == len(ru_targets), (f'Mismatch between the number of computed and requested metrics.'
70-
f'Requested {ru_targets}')
71-
return ru
52+
act_qcs, w_qcs = self.get_quantization_candidates(mp_cfg)
53+
54+
ru, detailed_ru = self.ru_calculator.compute_resource_utilization(TargetInclusionCriterion.AnyQuantized,
55+
BitwidthMode.QCustom,
56+
act_qcs=act_qcs,
57+
w_qcs=w_qcs,
58+
ru_targets=ru_targets,
59+
allow_unused_qcs=True,
60+
return_detailed=True)
61+
62+
ru_dict = {k: np.array([v]) for k, v in ru.get_resource_utilization_dict(restricted_only=True).items()}
63+
# For activation and total we need utilization per cut, as different mp configurations might result in
64+
# different cuts being maximal.
65+
for target in [RUTarget.ACTIVATION, RUTarget.TOTAL]:
66+
if target in ru_dict:
67+
ru_dict[target] = np.array(list(detailed_ru[target].values()))
68+
69+
assert all(v.ndim == 1 for v in ru_dict.values())
70+
if RUTarget.ACTIVATION in ru_targets and RUTarget.TOTAL in ru_targets:
71+
assert ru_dict[RUTarget.ACTIVATION].shape == ru_dict[RUTarget.TOTAL].shape
72+
73+
assert len(ru_dict) == len(ru_targets), (f'Mismatch between the number of computed and requested metrics.'
74+
f'Requested {ru_targets}')
75+
return ru_dict
7276

7377
def get_quantization_candidates(self, mp_cfg) \
74-
-> Tuple[Dict[BaseNode, NodeActivationQuantizationConfig], Dict[BaseNode, NodeWeightsQuantizationConfig]]:
78+
-> Tuple[Dict[str, NodeActivationQuantizationConfig], Dict[str, NodeWeightsQuantizationConfig]]:
7579
"""
7680
Retrieve quantization candidates objects for weights and activations from the configuration list.
7781
@@ -87,71 +91,3 @@ def get_quantization_candidates(self, mp_cfg) \
8791
act_qcs = {n.name: cfg.activation_quantization_cfg for n, cfg in node_qcs.items()}
8892
w_qcs = {n.name: cfg.weights_quantization_cfg for n, cfg in node_qcs.items()}
8993
return act_qcs, w_qcs
90-
91-
def _weights_utilization(self, w_qcs: Optional[Dict[BaseNode, NodeWeightsQuantizationConfig]]) -> Dict[BaseNode, float]:
92-
"""
93-
Compute weights utilization for configurable weights if configuration is passed,
94-
or for non-configurable nodes otherwise.
95-
96-
Args:
97-
w_qcs: nodes quantization configuration to compute, or None.
98-
99-
Returns:
100-
Weight utilization per node.
101-
"""
102-
if w_qcs:
103-
target_criterion = TargetInclusionCriterion.QConfigurable
104-
bitwidth_mode = BitwidthMode.QCustom
105-
else:
106-
target_criterion = TargetInclusionCriterion.QNonConfigurable
107-
bitwidth_mode = BitwidthMode.QDefaultSP
108-
109-
_, nodes_util, _ = self.ru_calculator.compute_weights_utilization(target_criterion=target_criterion,
110-
bitwidth_mode=bitwidth_mode,
111-
w_qcs=w_qcs)
112-
nodes_util = {n: u.bytes for n, u in nodes_util.items()}
113-
return nodes_util
114-
115-
def _activation_utilization(self, act_qcs: Optional[Dict[BaseNode, NodeActivationQuantizationConfig]]) \
116-
-> Optional[Dict[Any, float]]:
117-
"""
118-
Compute activation utilization using MaxCut for all quantized nodes if configuration is passed.
119-
120-
Args:
121-
act_qcs: nodes activation configuration or None.
122-
123-
Returns:
124-
Activation utilization per cut, or empty dict if no configuration was passed.
125-
"""
126-
# Maxcut activation utilization is computed for all quantized nodes, so non-configurable memory is already
127-
# covered by the computation of configurable activations.
128-
if not act_qcs:
129-
return {}
130-
131-
_, cuts_util, *_ = self.ru_calculator.compute_activation_utilization_by_cut(
132-
TargetInclusionCriterion.AnyQuantized, bitwidth_mode=BitwidthMode.QCustom, act_qcs=act_qcs)
133-
cuts_util = {c: u.bytes for c, u in cuts_util.items()}
134-
return cuts_util
135-
136-
def _bops_utilization(self,
137-
act_qcs: Optional[Dict[BaseNode, NodeActivationQuantizationConfig]],
138-
w_qcs: Optional[Dict[BaseNode, NodeWeightsQuantizationConfig]]) -> np.ndarray:
139-
"""
140-
Computes a resource utilization vector with the respective bit-operations (BOPS) count
141-
according to the given mixed-precision configuration.
142-
143-
Args:
144-
act_qcs: nodes activation configuration or None.
145-
w_qcs: nodes quantization configuration to compute, or None.
146-
Either both are provided, or both are None.
147-
148-
Returns:
149-
A vector of node's BOPS count.
150-
"""
151-
assert [act_qcs, w_qcs].count(None) in [0, 2], 'act_qcs and w_qcs should both be provided or both be None.'
152-
if act_qcs is None:
153-
return np.array([])
154-
155-
_, detailed_bops = self.ru_calculator.compute_bops(TargetInclusionCriterion.Any, BitwidthMode.QCustom,
156-
act_qcs=act_qcs, w_qcs=w_qcs)
157-
return np.array(list(detailed_bops.values()))

model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py

Lines changed: 17 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -13,37 +13,27 @@
1313
# limitations under the License.
1414
# ==============================================================================
1515

16-
import copy
1716
from enum import Enum
18-
import numpy as np
19-
from typing import List, Callable, Dict
17+
from typing import List, Callable
2018

2119
from model_compression_toolkit.core import MixedPrecisionQuantizationConfig
2220
from model_compression_toolkit.core.common import Graph
23-
from model_compression_toolkit.core.common.hessian import HessianInfoService
24-
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization, RUTarget
2521
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
26-
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_search_manager import MixedPrecisionSearchManager
27-
from model_compression_toolkit.core.common.mixed_precision.search_methods.linear_programming import \
28-
mp_integer_programming_search
2922
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
23+
from model_compression_toolkit.core.common.hessian import HessianInfoService
24+
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_search_manager import \
25+
MixedPrecisionSearchManager
26+
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import \
27+
ResourceUtilization
3028
from model_compression_toolkit.core.common.mixed_precision.solution_refinement_procedure import \
3129
greedy_solution_refinement_procedure
32-
from model_compression_toolkit.core.common.substitutions.apply_substitutions import substitute
33-
from model_compression_toolkit.logger import Logger
3430

3531

3632
class BitWidthSearchMethod(Enum):
37-
# When adding a new search_methods MP configuration method, these enum and factory dictionary
38-
should be updated with its kind and a search_method implementation.
3933
INTEGER_PROGRAMMING = 0
4034

4135

42-
search_methods = {
43-
BitWidthSearchMethod.INTEGER_PROGRAMMING: mp_integer_programming_search}
44-
45-
46-
def search_bit_width(graph_to_search_cfg: Graph,
36+
def search_bit_width(graph: Graph,
4737
fw_info: FrameworkInfo,
4838
fw_impl: FrameworkImplementation,
4939
target_resource_utilization: ResourceUtilization,
@@ -60,7 +50,7 @@ def search_bit_width(graph_to_search_cfg: Graph,
6050
target_resource_utilization have to be passed. If it was not passed, the facade is not supposed to get here by now.
6151
6252
Args:
63-
graph_to_search_cfg: Graph to search a MP configuration for.
53+
graph: Graph to search a MP configuration for.
6454
fw_info: FrameworkInfo object about the specific framework (e.g., attributes of different layers' weights to quantize).
6555
fw_impl: FrameworkImplementation object with specific framework methods implementation.
6656
target_resource_utilization: Target Resource Utilization to bound our feasible solution space s.t the configuration does not violate it.
@@ -75,51 +65,36 @@ def search_bit_width(graph_to_search_cfg: Graph,
7565
bit-width index on the node).
7666
7767
"""
78-
79-
# target_resource_utilization have to be passed. If it was not passed, the facade is not supposed to get here by now.
80-
if target_resource_utilization is None:
81-
Logger.critical("Target ResourceUtilization is required for the bit-width search method's configuration.") # pragma: no cover
82-
83-
# Set graph for MP search
84-
graph = copy.deepcopy(graph_to_search_cfg) # Copy graph before searching
85-
if target_resource_utilization.bops_restricted():
86-
# TODO: we only need the virtual graph if both activations and weights are configurable
87-
# Since Bit-operations count target resource utilization is set, we need to reconstruct the graph for the MP search
88-
graph = substitute(graph, fw_impl.get_substitutions_virtual_weights_activation_coupling())
68+
assert target_resource_utilization.is_any_restricted()
8969

9070
# If we only run weights compression with MP than no need to consider activation quantization when computing the
9171
# MP metric (it adds noise to the computation)
9272
tru = target_resource_utilization
9373
weight_only_restricted = tru.weight_restricted() and not (tru.activation_restricted() or
9474
tru.total_mem_restricted() or
9575
tru.bops_restricted())
96-
disable_activation_for_metric = weight_only_restricted or graph_to_search_cfg.is_single_activation_cfg()
76+
disable_activation_for_metric = weight_only_restricted or not graph.has_any_configurable_activation()
9777

9878
# Set Sensitivity Evaluator for MP search. It should always work with the original MP graph,
9979
# even if a virtual graph was created (and is used only for BOPS utilization computation purposes)
10080
se = fw_impl.get_sensitivity_evaluator(
101-
graph_to_search_cfg,
81+
graph,
10282
mp_config,
10383
representative_data_gen=representative_data_gen,
10484
fw_info=fw_info,
10585
disable_activation_for_metric=disable_activation_for_metric,
10686
hessian_info_service=hessian_info_service)
10787

108-
# Instantiate a manager object
88+
if search_method != BitWidthSearchMethod.INTEGER_PROGRAMMING:
89+
raise NotImplementedError()
90+
91+
# Search manager and LP are highly coupled, so LP search method was moved inside search manager.
10992
search_manager = MixedPrecisionSearchManager(graph,
11093
fw_info,
11194
fw_impl,
11295
se,
113-
target_resource_utilization,
114-
original_graph=graph_to_search_cfg)
115-
116-
if search_method not in search_methods:
117-
raise NotImplementedError() # pragma: no cover
118-
119-
search_method_fn = search_methods[search_method]
120-
# Search for the desired mixed-precision configuration
121-
result_bit_cfg = search_method_fn(search_manager,
122-
target_resource_utilization)
96+
target_resource_utilization)
97+
result_bit_cfg = search_manager.search()
12398

12499
if mp_config.refine_mp_solution:
125100
result_bit_cfg = greedy_solution_refinement_procedure(result_bit_cfg, search_manager, target_resource_utilization)

0 commit comments

Comments
 (0)