Skip to content
Merged
6 changes: 6 additions & 0 deletions model_compression_toolkit/core/common/framework_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@
from model_compression_toolkit.defaultdict import DefaultDict


# Default value to use for ops without kernel.
# This is a weird default, but it is used all over the place, so for now we only extract it to a
# named constant so that it can be referenced by name instead of being hard-coded.
DEFAULT_KERNEL_ATTRIBUTES = [None]


class ChannelAxis(Enum):
"""

Expand Down
28 changes: 9 additions & 19 deletions model_compression_toolkit/core/common/graph/base_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from copy import copy, deepcopy
from functools import wraps
from typing import List, Tuple, Any, Callable
from typing import List, Tuple, Any, Callable, Dict

import networkx as nx
import numpy as np
Expand Down Expand Up @@ -683,7 +683,7 @@ def _sort_nodes_in_list(self, nodes_list: List[BaseNode]) -> List[BaseNode]:
sorted_configurable_nodes.append(n)
return sorted_configurable_nodes

def get_min_candidates_config(self, fw_info: FrameworkInfo) -> List[int]:
def get_min_candidates_config(self, fw_info: FrameworkInfo) -> Dict[BaseNode, int]:
"""
Builds a minimal configuration.
Note: we assume that a minimal configuration exists, i.e., each configurable node has exactly one candidate
Expand All @@ -693,18 +693,13 @@ def get_min_candidates_config(self, fw_info: FrameworkInfo) -> List[int]:
Args:
fw_info: FrameworkInfo object with information about the specific framework's model.

Returns: A list of candidate for each node (list on indices)
Returns:
A dict from layer to an index of its minimal candidate.
"""

conf_sorted_nodes = self.get_configurable_sorted_nodes(fw_info)
min_cfg_candidates = [n.find_min_candidates_indices() for n in conf_sorted_nodes] # list of lists of indices

assert all([len(lst) == 1 for lst in min_cfg_candidates]), \
f"A minimal config candidate must be defined, but some node have multiple potential minimal candidates"

return [lst[0] for lst in min_cfg_candidates]
return {n: n.find_min_candidate_index() for n in conf_sorted_nodes}

def get_max_candidates_config(self, fw_info: FrameworkInfo) -> List[int]:
def get_max_candidates_config(self, fw_info: FrameworkInfo) -> Dict[BaseNode, int]:
"""
Builds a maximal configuration.
Note: we assume that a maximal configuration exists, i.e., each configurable node has exactly one candidate
Expand All @@ -714,16 +709,11 @@ def get_max_candidates_config(self, fw_info: FrameworkInfo) -> List[int]:
Args:
fw_info: FrameworkInfo object with information about the specific framework's model.

Returns: A list of candidate for each node (list on indices)
Returns:
A dict from layer to an index of its maximal candidate.
"""

conf_sorted_nodes = self.get_configurable_sorted_nodes(fw_info)
max_cfg_candidates = [n.find_max_candidates_indices() for n in conf_sorted_nodes] # list of lists of indices

assert all([len(lst) == 1 for lst in max_cfg_candidates]), \
f"A maximal config candidate must be defined, but some node have multiple potential maximal candidates"

return [lst[0] for lst in max_cfg_candidates]
return {n: n.find_max_candidate_index() for n in conf_sorted_nodes}

def get_final_weights_config(self, fw_info: FrameworkInfo) -> List[Tuple[BaseNode, int]]:
"""
Expand Down
64 changes: 25 additions & 39 deletions model_compression_toolkit/core/common/graph/base_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,49 +488,35 @@ def get_total_output_params(self) -> float:
# for scalar shape (None,) prod returns 1
return sum([np.prod([x for x in output_shape if x is not None]) for output_shape in output_shapes])

def find_min_candidates_indices(self) -> List[int]:
def find_min_candidate_index(self) -> int:
"""
Returns a list with potential minimal candidates.
A potential minimal candidate is a candidate whose (weights_n_bits, activation_n_bits) pair lies
on the Pareto front, i.e., there is no other candidate whose n_bits pair is lower in both entries.

Returns: A list of indices of potential minimal candidates.

"""

# We assume that the candidates are sorted according to weights_n_bits first and activation_n_bits second
# First, we add the last candidate to the set of minimal candidates (candidate, index)
first_min = (len(self.candidates_quantization_cfg) - 1,
self.candidates_quantization_cfg[-1].activation_quantization_cfg.activation_n_bits)
min_candidates = [first_min]

# Iterate over all other candidates, and add ones with higher weights_n_bits but smaller activation_n_bits
for i, c in reversed(list(enumerate(self.candidates_quantization_cfg))):
if c.activation_quantization_cfg.activation_n_bits < first_min[1]:
min_candidates.append((i, c))

return [i for i, a_n_bits in min_candidates]

def find_max_candidates_indices(self) -> List[int]:
Returns:
The index of the minimal bit-width candidate.
"""
Returns a list with potential maximal candidates.
A potential maximal candidate is a candidate whose (weights_n_bits, activation_n_bits) pair lies
on the Pareto front, i.e., there is no other candidate whose n_bits pair is higher in both entries.
aw_nbits = [(c.activation_quantization_cfg.activation_n_bits,
*[v.weights_n_bits for v in c.weights_quantization_cfg.get_all_weight_attrs_configs().values()])
for c in self.candidates_quantization_cfg]
min_nbits = min(aw_nbits)
min_ind = [i for i, nb in enumerate(aw_nbits) if min_nbits == nb]
# check that no other candidate has a lower nbit for any weight
if len(min_ind) > 1 or any(nb[i] < min_nbits[i] for i in range(len(min_nbits)) for nb in aw_nbits):
raise ValueError('Expected exactly one candidate with min activation and min weights.')
return min_ind[0]

Returns: A list of indices of potential maximal candidates.
def find_max_candidate_index(self) -> int:
"""

# We assume that the candidates are sorted according to weights_n_bits first and activation_n_bits second
# First, we add the first candidate to the set of maximal candidates (candidate, index)
first_max = (0, self.candidates_quantization_cfg[0].activation_quantization_cfg.activation_n_bits)
max_candidates = [first_max]

# Iterate over all other candidates, and add ones with larger activation_n_bits
for i, c in enumerate(self.candidates_quantization_cfg):
if c.activation_quantization_cfg.activation_n_bits > first_max[1]:
max_candidates.append((i, c))

return [i for i, a_n_bits in max_candidates]
Returns:
The index of the maximal bit-width candidate.
"""
aw_nbits = [(c.activation_quantization_cfg.activation_n_bits,
*[v.weights_n_bits for v in c.weights_quantization_cfg.get_all_weight_attrs_configs().values()])
for c in self.candidates_quantization_cfg]
max_nbits = max(aw_nbits)
max_ind = [i for i, nb in enumerate(aw_nbits) if max_nbits == nb]
# check that no other candidate has a higher nbit for any weight
if len(max_ind) > 1 or any(nb[i] > max_nbits[i] for i in range(len(max_nbits)) for nb in aw_nbits):
raise ValueError('Expected exactly one candidate with max activation and max weights.')
return max_ind[0]

def get_unique_weights_candidates(self, attr: str) -> List[Any]:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import List, Set, Dict, Tuple
from typing import Set, Dict, Tuple

import numpy as np

from model_compression_toolkit.core import FrameworkInfo
from model_compression_toolkit.core.common import Graph
from model_compression_toolkit.core.common import Graph, BaseNode
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import \
RUTarget
Expand All @@ -36,7 +36,7 @@ def __init__(self, graph: Graph, fw_info: FrameworkInfo, fw_impl: FrameworkImple
self.fw_impl = fw_impl
self.ru_calculator = ResourceUtilizationCalculator(graph, fw_impl, fw_info)

def compute_utilization(self, ru_targets: Set[RUTarget], mp_cfg: List[int]) -> Dict[RUTarget, np.ndarray]:
def compute_utilization(self, ru_targets: Set[RUTarget], mp_cfg: Dict[BaseNode, int]) -> Dict[RUTarget, np.ndarray]:
"""
Compute utilization of requested targets for a specific configuration:
for weights and bops - total utilization,
Expand Down Expand Up @@ -74,7 +74,7 @@ def compute_utilization(self, ru_targets: Set[RUTarget], mp_cfg: List[int]) -> D
f'Requested {ru_targets}')
return ru_dict

def get_quantization_candidates(self, mp_cfg) \
def get_quantization_candidates(self, mp_cfg: Dict[BaseNode, int]) \
-> Tuple[Dict[str, NodeActivationQuantizationConfig], Dict[str, NodeWeightsQuantizationConfig]]:
"""
Retrieve quantization candidate objects for weights and activations from the configuration mapping.
Expand All @@ -86,8 +86,7 @@ def get_quantization_candidates(self, mp_cfg) \
A mapping between nodes to weights quantization config, and a mapping between nodes and activation
quantization config.
"""
mp_nodes = self.graph.get_configurable_sorted_nodes(self.fw_info)
node_qcs = {n: n.candidates_quantization_cfg[mp_cfg[i]] for i, n in enumerate(mp_nodes)}
node_qcs = {n: n.candidates_quantization_cfg[candidate_idx] for n, candidate_idx in mp_cfg.items()}
act_qcs = {n.name: cfg.activation_quantization_cfg for n, cfg in node_qcs.items()}
w_qcs = {n.name: cfg.weights_quantization_cfg for n, cfg in node_qcs.items()}
return act_qcs, w_qcs
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
# ==============================================================================

from enum import Enum
from typing import List, Callable
from typing import List, Callable, Dict

from model_compression_toolkit.core import MixedPrecisionQuantizationConfig
from model_compression_toolkit.core.common import Graph
from model_compression_toolkit.core.common import Graph, BaseNode
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
from model_compression_toolkit.core.common.hessian import HessianInfoService
Expand Down Expand Up @@ -100,11 +100,13 @@ def search_bit_width(graph: Graph,
fw_impl,
se,
target_resource_utilization)
result_bit_cfg = search_manager.search()
nodes_bit_cfg = search_manager.search()

graph.skip_validation_check = False

if mp_config.refine_mp_solution:
result_bit_cfg = greedy_solution_refinement_procedure(result_bit_cfg, search_manager, target_resource_utilization)
nodes_bit_cfg = greedy_solution_refinement_procedure(nodes_bit_cfg, search_manager, target_resource_utilization)

return result_bit_cfg
topo_bit_cfg = [nodes_bit_cfg[n] for n in graph.get_configurable_sorted_nodes(fw_info)]
assert len(topo_bit_cfg) == len(nodes_bit_cfg)
return topo_bit_cfg
Loading