Skip to content

Commit 018b352

Browse files
Integrate manual weights bitwidth selection into the MCT core (#1395)
* integrate weights manual selection bitwidth * fix comment in test code * Fix in set_node_quantization_config.py * debugging version(040700gouda) * debugging version(040701gouda) * delete debug-print * debugging version(040900gouda) * revert generate_test_tpc.py * fixed feedback comments. * delete old tests for manual_weights_bitwidth_selection * remove unnecessary changes. * add typehint and description. * fixed PR comments. * remove unnecessary changes(new lines). * fixed PR comments. * fixed PR comments. * fixed PR comments. (minimum tpc setting for e2e tests.) --------- Co-authored-by: gouda-youichi <gouda.youichi@jp.panasonic.com>
1 parent 9f8df99 commit 018b352

File tree

4 files changed

+591
-25
lines changed

4 files changed

+591
-25
lines changed

model_compression_toolkit/constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@
6363
CORRECTED_BIAS_ATTRIBUTE = 'corrected_bias'
6464
ACTIVATION_N_BITS_ATTRIBUTE = 'activation_n_bits'
6565
SUPPORTED_INPUT_ACTIVATION_NBITS_ATTRIBUTE = 'supported_input_activation_n_bits'
66+
WEIGHTS = 'weights'
67+
ACTIVATION = 'activation'
6668

6769
# Quantization Parameters Iterative Search Defaults:
6870
SYMMETRIC_TENSOR_N_ITER = 40

model_compression_toolkit/core/common/quantization/set_node_quantization_config.py

Lines changed: 162 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,11 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
# ==============================================================================
15-
16-
1715
import copy
18-
from typing import List, Tuple, Optional
16+
from typing import List, Tuple, Dict, Optional
1917

20-
from mct_quantizers.common.constants import ACTIVATION_N_BITS
18+
from mct_quantizers.common.constants import WEIGHTS_N_BITS, ACTIVATION_N_BITS
19+
from model_compression_toolkit.constants import WEIGHTS, ACTIVATION
2120
from model_compression_toolkit.core.common import BaseNode
2221
from model_compression_toolkit.core.common.quantization.bit_width_config import BitWidthConfig
2322
from model_compression_toolkit.logger import Logger
@@ -30,14 +29,13 @@
3029
from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig, \
3130
QuantizationErrorMethod
3231
from model_compression_toolkit.core.common.quantization.quantization_params_fn_selection import \
33-
get_activation_quantization_params_fn, get_weights_quantization_params_fn
34-
from model_compression_toolkit.core.common.quantization.quantization_fn_selection import \
35-
get_weights_quantization_fn
32+
get_activation_quantization_params_fn
3633
from model_compression_toolkit.target_platform_capabilities.schema.schema_functions import max_input_activation_n_bits
3734
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OpQuantizationConfig, \
3835
QuantizationConfigOptions
3936
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.framework_quantization_capabilities import \
4037
FrameworkQuantizationCapabilities
38+
from model_compression_toolkit.core.common.graph.base_node import WeightAttrT
4139

4240

4341
def set_quantization_configuration_to_graph(graph: Graph,
@@ -66,16 +64,19 @@ def set_quantization_configuration_to_graph(graph: Graph,
6664
Logger.warning("Using the HMSE error method for weights quantization parameters search. "
6765
"Note: This method may significantly increase runtime during the parameter search process.")
6866

69-
nodes_to_manipulate_bit_widths = {} if bit_width_config is None else bit_width_config.get_nodes_to_manipulate_activation_bit_widths(graph)
67+
nodes_to_manipulate_activation_bit_widths = {} if bit_width_config is None else bit_width_config.get_nodes_to_manipulate_activation_bit_widths(graph)
68+
nodes_to_manipulate_weights_bit_widths = {} if bit_width_config is None else bit_width_config.get_nodes_to_manipulate_weights_bit_widths(graph)
7069

7170
for n in graph.nodes:
71+
manual_bit_width_override = {ACTIVATION: nodes_to_manipulate_activation_bit_widths.get(n),
72+
WEIGHTS: nodes_to_manipulate_weights_bit_widths.get(n)}
7273
set_quantization_configs_to_node(node=n,
7374
graph=graph,
7475
quant_config=quant_config,
7576
fw_info=graph.fw_info,
7677
fqc=graph.fqc,
7778
mixed_precision_enable=mixed_precision_enable,
78-
manual_bit_width_override=nodes_to_manipulate_bit_widths.get(n))
79+
manual_bit_width_override=manual_bit_width_override)
7980
return graph
8081

8182

@@ -151,7 +152,7 @@ def set_quantization_configs_to_node(node: BaseNode,
151152
fw_info: FrameworkInfo,
152153
fqc: FrameworkQuantizationCapabilities,
153154
mixed_precision_enable: bool = False,
154-
manual_bit_width_override: Optional[int] = None):
155+
manual_bit_width_override: Optional[Dict] = None):
155156
"""
156157
Create and set quantization configurations to a node (for both weights and activation).
157158
@@ -167,8 +168,11 @@ def set_quantization_configs_to_node(node: BaseNode,
167168
node_qc_options = node.get_qco(fqc)
168169
base_config, node_qc_options_list = filter_node_qco_by_graph(node, fqc, graph, node_qc_options)
169170

170-
# If a manual_bit_width_override is given, filter node_qc_options_list to retain only the options with activation bits equal to manual_bit_width_override,
171+
# If a manual_bit_width_override is given, filter node_qc_options_list to retain only the options with activation and weights bits equal to manual_bit_width_override,
171172
# and update base_config accordingly.
173+
if manual_bit_width_override is None:
174+
manual_bit_width_override = {ACTIVATION: None, WEIGHTS: None}
175+
172176
base_config, node_qc_options_list = filter_qc_options_with_manual_bit_width(
173177
node=node,
174178
node_qc_options_list=node_qc_options_list,
@@ -322,7 +326,7 @@ def filter_qc_options_with_manual_bit_width(
322326
node: BaseNode,
323327
node_qc_options_list: List[OpQuantizationConfig],
324328
base_config: OpQuantizationConfig,
325-
manual_bit_width_override: Optional[int],
329+
manual_bit_width_override: Optional[Dict],
326330
mixed_precision_enable: bool) -> Tuple[OpQuantizationConfig, List[OpQuantizationConfig]]:
327331
"""
328332
Update the quantization configurations for a node, allowing manual bit-width overrides if specified.
@@ -331,36 +335,169 @@ def filter_qc_options_with_manual_bit_width(
331335
node (BaseNode): A node to set quantization configuration candidates to.
332336
node_qc_options_list (List[OpQuantizationConfig]): List of quantization configs for the node.
333337
base_config (OpQuantizationConfig): Base quantization config for the node.
334-
manual_bit_width_override (Optional[int]): Specifies a custom bit-width to override the node's activation bit-width.
338+
manual_bit_width_override (Optional[Dict]): Specifies a custom bit-width to override the node's activation and weights bit-width.
335339
mixed_precision_enable (bool): Whether mixed precision is enabled.
336340
337341
Returns:
338342
Tuple[OpQuantizationConfig, List[OpQuantizationConfig]]: The updated base configuration and the filtered list of quantization configs.
339343
"""
340-
if manual_bit_width_override is None:
344+
base_config, node_qc_options_list = filter_activation_qc_options_with_manual_bit_width(node,
345+
node_qc_options_list,
346+
base_config,
347+
manual_bit_width_override.get(ACTIVATION),
348+
mixed_precision_enable)
349+
350+
base_config, node_qc_options_list = filter_weights_qc_options_with_manual_bit_width(node,
351+
node_qc_options_list,
352+
base_config,
353+
manual_bit_width_override.get(WEIGHTS),
354+
mixed_precision_enable)
355+
return base_config, node_qc_options_list
356+
357+
358+
def filter_activation_qc_options_with_manual_bit_width(
359+
node: BaseNode,
360+
node_qc_options_list: List[OpQuantizationConfig],
361+
base_config: OpQuantizationConfig,
362+
activation_manual_bit_width_override: Optional[int],
363+
mixed_precision_enable: bool) -> Tuple[OpQuantizationConfig, List[OpQuantizationConfig]]:
364+
"""
365+
Update the activation quantization configurations for a node, allowing manual bit-width overrides if specified.
366+
367+
Args:
368+
node (BaseNode): A node to set quantization configuration candidates to.
369+
node_qc_options_list (List[OpQuantizationConfig]): List of quantization configs for the node.
370+
base_config (OpQuantizationConfig): Base quantization config for the node.
371+
activation_manual_bit_width_override (Optional[Dict]): Specifies a custom bit-width to override the node's activation bit-width.
372+
mixed_precision_enable (bool): Whether mixed precision is enabled.
373+
374+
Returns:
375+
Tuple[OpQuantizationConfig, List[OpQuantizationConfig]]: The updated base configuration and the filtered list of quantization configs.
376+
"""
377+
if activation_manual_bit_width_override is None:
341378
return base_config, node_qc_options_list
342379

343-
# Filter node_qc_options_list to retain only the options with activation bits equal to manual_bit_width_override.
380+
# Filter node_qc_options_list to retain only the options with activation bits equal to activation_manual_bit_width_override.
344381
node_qc_options_list = [op_cfg for op_cfg in node_qc_options_list if
345-
manual_bit_width_override == op_cfg.activation_n_bits]
346-
382+
activation_manual_bit_width_override == op_cfg.activation_n_bits]
347383
if len(node_qc_options_list) == 0:
348-
Logger.critical(f"Manually selected activation bit-width {manual_bit_width_override} is invalid for node {node}.")
384+
Logger.critical(f"Manually selected activation bit-width {activation_manual_bit_width_override} is invalid for node {node}.")
349385
else:
350386
# Update the base_config to one of the values from the filtered node_qc_options_list.
351-
# First, check if a configuration similar to the original base_config but with activation bits equal to manual_bit_width_override exists.
387+
# First, check if a configuration similar to the original base_config but with activation bits equal to activation_manual_bit_width_override exists.
352388
# If it does, use it as the base_config. If not, choose a different configuration from node_qc_options_list.
353-
Logger.info(f"Setting node {node} bit-width to manually selected bit-width: {manual_bit_width_override} bits.")
354-
updated_base_config = base_config.clone_and_edit({ACTIVATION_N_BITS, manual_bit_width_override})
389+
Logger.info(f"Setting node {node} bit-width to manually selected bit-width: {activation_manual_bit_width_override} bits.")
390+
updated_base_config = base_config.clone_and_edit({ACTIVATION_N_BITS, activation_manual_bit_width_override})
355391
if updated_base_config in node_qc_options_list:
356-
# If a base_config with the specified manual_bit_width_override exists in the node_qc_options_list,
392+
# If a base_config with the specified activation_manual_bit_width_override exists in the node_qc_options_list,
357393
# point the base_config to this option.
358394
base_config = node_qc_options_list[node_qc_options_list.index(updated_base_config)]
359395
else:
360396
# Choose a different configuration from node_qc_options_list. If multiple options exist, issue a warning.
361397
base_config = node_qc_options_list[0]
362398
if len(node_qc_options_list) > 0 and not mixed_precision_enable:
363399
Logger.info(
364-
f"Request received to select {manual_bit_width_override} activation bits. However, the base configuration for layer type {node.type} is missing in the node_qc_options_list."
365-
f" Overriding base_config with an option that uses {manual_bit_width_override} bit activations.") # pragma: no cover
366-
return base_config, node_qc_options_list
400+
f"Request received to select {activation_manual_bit_width_override} activation bits. However, the base configuration for layer type {node.type} is missing in the node_qc_options_list."
401+
f" Overriding base_config with an option that uses {activation_manual_bit_width_override} bit activations.") # pragma: no cover
402+
403+
return base_config, node_qc_options_list
404+
405+
406+
def filter_weights_qc_options_with_manual_bit_width(
407+
node: BaseNode,
408+
node_qc_options_list: List[OpQuantizationConfig],
409+
base_config: OpQuantizationConfig,
410+
weights_manual_bit_width_override: Optional[Tuple[int, WeightAttrT]],
411+
mixed_precision_enable: bool) -> Tuple[OpQuantizationConfig, List[OpQuantizationConfig]]:
412+
"""
413+
Update the weights quantization configurations for a node, allowing manual bit-width overrides if specified.
414+
415+
Args:
416+
node (BaseNode): A node to set quantization configuration candidates to.
417+
node_qc_options_list (List[OpQuantizationConfig]): List of quantization configs for the node.
418+
base_config (OpQuantizationConfig): Base quantization config for the node.
419+
weights_manual_bit_width_override (Optional[[int, WeightAttrT]]): Specifies a custom bit-width to override the node's weights bit-width.
420+
mixed_precision_enable (bool): Whether mixed precision is enabled.
421+
422+
Returns:
423+
Tuple[OpQuantizationConfig, List[OpQuantizationConfig]]: The updated base configuration and the filtered list of quantization configs.
424+
"""
425+
if not weights_manual_bit_width_override:
426+
return base_config, node_qc_options_list
427+
428+
# Filter node_qc_options_list to retain only the options with weights bits equal to weights_manual_bit_width_override.
429+
node_qc_options_weights_list = _filter_options(node_qc_options_list, weights_manual_bit_width_override)
430+
431+
if len(node_qc_options_weights_list) == 0:
432+
Logger.critical(f"Manually selected weights bit-width {weights_manual_bit_width_override} is invalid for node {node}.")
433+
else:
434+
# Update the base_config to one of the values from the filtered node_qc_options_list.
435+
# First, check if a configuration similar to the original base_config but with weights bits equal to weights_manual_bit_width_override exists.
436+
# If it does, use it as the base_config. If not, choose a different configuration from node_qc_options_list.
437+
updated_base_config = base_config.clone_and_edit()
438+
439+
for bit_width, attr in weights_manual_bit_width_override:
440+
Logger.info(f"Setting node {node} bit-width to manually selected {attr} bit-width: {bit_width} bits.")
441+
updated_base_config = updated_base_config.clone_and_edit(attr_to_edit={attr : {WEIGHTS_N_BITS: bit_width}})
442+
443+
if updated_base_config in node_qc_options_weights_list:
444+
# If a base_config with the specified weights_manual_bit_width_override exists in the node_qc_options_list,
445+
# point the base_config to this option.
446+
base_config = node_qc_options_weights_list[node_qc_options_weights_list.index(updated_base_config)]
447+
else:
448+
# Choose a different configuration from node_qc_options_list. If multiple options exist, issue a warning.
449+
base_config = node_qc_options_weights_list[0]
450+
if len(node_qc_options_weights_list) > 0 and not mixed_precision_enable:
451+
Logger.info(
452+
f"Request received to select weights bit-widths {weights_manual_bit_width_override}."
453+
f"However, the base configuration for layer type {node.type} is missing in the node_qc_options_list."
454+
f" Overriding base_config with an option that uses manually selected weights bit-widths {weights_manual_bit_width_override}.") # pragma: no cover
455+
456+
return base_config, node_qc_options_weights_list
457+
458+
459+
def _is_valid_option(
460+
op_cfg: OpQuantizationConfig,
461+
attr: WeightAttrT,
462+
bit_width: int) -> bool:
463+
"""
464+
Judge whether the specified option is valid based on the specified attribute and bit width.
465+
466+
Args:
467+
op_cfg (OpQuantizationConfig): The quantization configuration to be judged.
468+
attr (WeightAttrT): The filtered node's attributes to apply bit-width manipulation to.
469+
bit_width (int): The bit width to be applied to the selected nodes.
470+
471+
Returns:
472+
Result to judge whether the specified option is valid based on the specified attribute and bit width
473+
"""
474+
weights_attrs = op_cfg.attr_weights_configs_mapping.keys()
475+
476+
if attr not in weights_attrs:
477+
return False
478+
479+
weights_n_bits = op_cfg.attr_weights_configs_mapping[attr].weights_n_bits
480+
return weights_n_bits == bit_width
481+
482+
483+
def _filter_options(
484+
node_qc_options_list: List[OpQuantizationConfig],
485+
weights_manual_bit_width_override: Tuple[int, WeightAttrT]) -> List[OpQuantizationConfig]:
486+
"""
487+
Filter the options based on the specified bit width and attribute.
488+
489+
Args:
490+
node_qc_options_list (List[OpQuantizationConfig]): List of quantization configs for the node.
491+
weights_manual_bit_width_override (Tuple[int, WeightAttrT])): Specifies a custom bit-width to override the node's weights bit-width.
492+
493+
Returns:
494+
List[OpQuantizationConfig]: Filtered the options based on the specified bit width and attribute.
495+
"""
496+
filtered_options = []
497+
498+
for bit_width, attr in weights_manual_bit_width_override:
499+
for op_cfg in node_qc_options_list:
500+
if _is_valid_option(op_cfg, attr, bit_width):
501+
filtered_options.append(op_cfg)
502+
503+
return filtered_options

0 commit comments

Comments
 (0)