Skip to content
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
from model_compression_toolkit.core.common.mixed_precision.sensitivity_eval.sensitivity_evaluation import SensitivityEvaluation
from model_compression_toolkit.core.common.mixed_precision.solution_refinement_procedure import \
greedy_solution_refinement_procedure
from model_compression_toolkit.core.common.progress_config.progress_info_controller import \
ProgressInfoController


class BitWidthSearchMethod(Enum):
Expand All @@ -41,7 +43,8 @@ def search_bit_width(graph: Graph,
mp_config: MixedPrecisionQuantizationConfig,
representative_data_gen: Callable,
search_method: BitWidthSearchMethod = BitWidthSearchMethod.INTEGER_PROGRAMMING,
hessian_info_service: HessianInfoService = None) -> List[int]:
hessian_info_service: HessianInfoService = None,
progress_info_controller: ProgressInfoController = None) -> List[int]:
"""
Search for an MP configuration for a given graph. Given a search_method method (by default, it's linear
programming), we use the sensitivity_evaluator object that provides a function to compute an
Expand All @@ -59,6 +62,7 @@ def search_bit_width(graph: Graph,
representative_data_gen: Dataset to use for retrieving images for the models inputs.
search_method: BitWidthSearchMethod to define which searching method to use.
hessian_info_service: HessianInfoService to fetch Hessian-approximation information.
progress_info_controller: ProgressInfoController to display and manage overall progress information.

Returns:
A MP configuration for the graph (list of integers, where the index in the list, is the node's
Expand All @@ -81,7 +85,8 @@ def search_bit_width(graph: Graph,
# even if a virtual graph was created (and is used only for BOPS utilization computation purposes)
se = SensitivityEvaluation(graph, mp_config, representative_data_gen=representative_data_gen, fw_info=fw_info,
fw_impl=fw_impl, disable_activation_for_metric=disable_activation_for_metric,
hessian_info_service=hessian_info_service)
hessian_info_service=hessian_info_service,
progress_info_controller=progress_info_controller)

if search_method != BitWidthSearchMethod.INTEGER_PROGRAMMING:
raise NotImplementedError()
Expand All @@ -97,7 +102,8 @@ def search_bit_width(graph: Graph,
fw_impl=fw_impl,
sensitivity_evaluator=se,
target_resource_utilization=target_resource_utilization,
mp_config=mp_config)
mp_config=mp_config,
progress_info_controller=progress_info_controller)
nodes_bit_cfg = search_manager.search()

graph.skip_validation_check = False
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
MixedPrecisionQuantizationConfig, MpMetricNormalization
from model_compression_toolkit.core.common.progress_config.progress_info_controller import \
ProgressInfoController


class MixedPrecisionSearchManager:
Expand All @@ -57,7 +59,8 @@ def __init__(self,
fw_impl: FrameworkImplementation,
sensitivity_evaluator: SensitivityEvaluation,
target_resource_utilization: ResourceUtilization,
mp_config: MixedPrecisionQuantizationConfig):
mp_config: MixedPrecisionQuantizationConfig,
progress_info_controller: ProgressInfoController = None):
"""

Args:
Expand All @@ -67,11 +70,14 @@ def __init__(self,
sensitivity_evaluator: A SensitivityEvaluation which provides a function that evaluates the sensitivity of
a bit-width configuration for the MP model.
target_resource_utilization: Target Resource Utilization to bound our feasible solution space s.t the configuration does not violate it.
progress_info_controller: ProgressInfoController to display and manage overall progress information.
"""

self.fw_info = fw_info
self.fw_impl = fw_impl

self.progress_info_controller = progress_info_controller

self.original_graph = graph
# graph for mp search
self.mp_graph, self.using_virtual_graph = self._get_mp_graph(graph, target_resource_utilization)
Expand Down Expand Up @@ -183,6 +189,9 @@ def ensure_maxbit_minimal_metric(node_candidates_metrics, max_ind):
metrics[max_ind] = max_val
return metrics

if self.progress_info_controller is not None:
self.progress_info_controller.set_description('Research Mixed Precision')

layer_to_metrics_mapping = {}
debug_mapping = {}
for node_idx, node in tqdm(enumerate(self.mp_topo_configurable_nodes)):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from model_compression_toolkit.core.common.model_builder_mode import ModelBuilderMode
from model_compression_toolkit.core.common.similarity_analyzer import compute_kl_divergence
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.core.common.progress_config.progress_info_controller import \
ProgressInfoController


@runtime_checkable
Expand Down Expand Up @@ -64,7 +66,8 @@ def __init__(self,
representative_data_gen: Callable,
fw_info: FrameworkInfo,
fw_impl: Any,
hessian_info_service: HessianInfoService = None):
hessian_info_service: HessianInfoService = None,
progress_info_controller: ProgressInfoController = None):
"""
Args:
graph: Graph to search for its MP configuration.
Expand All @@ -74,6 +77,7 @@ def __init__(self,
fw_impl: FrameworkImplementation object with a specific framework methods implementation.
representative_data_gen: Dataset used for getting batches for inference.
hessian_info_service: HessianInfoService to fetch Hessian approximation information.
progress_info_controller: ProgressInfoController to display and manage overall progress information.
"""
self.graph = graph
self.mp_config = mp_config
Expand Down Expand Up @@ -121,7 +125,7 @@ def __init__(self,
# Hessian-based scores for weighted average distance metric computation
self.interest_points_hessians = None
if self.mp_config.distance_weighting_method == MpDistanceWeighting.HESSIAN:
self.interest_points_hessians = self._compute_hessian_based_scores(hessian_info_service)
self.interest_points_hessians = self._compute_hessian_based_scores(hessian_info_service, progress_info_controller)

def compute(self, mp_model) -> float:
"""
Expand Down Expand Up @@ -168,16 +172,20 @@ def _init_baseline_tensors_list(self):
return [self.fw_impl.to_numpy(self.fw_impl.sensitivity_eval_inference(self.ref_model, images))
for images in self.images_batches]

def _compute_hessian_based_scores(self, hessian_info_service: HessianInfoService) -> np.ndarray:
def _compute_hessian_based_scores(self, hessian_info_service: HessianInfoService, progress_info_controller: ProgressInfoController) -> np.ndarray:
"""
Compute Hessian-based scores for each interest point.
Args:
hessian_info_service: Hessian service.
progress_info_controller: Progress infomation controller.

Returns:
A vector of scores, one for each interest point, to be used for the distance metric weighted average computation.

"""
if progress_info_controller is not None:
progress_info_controller.set_description('Compute Hessian for Mixed Precision')

# Create a request for Hessian approximation scores with specific configurations
# (here we use per-tensor approximation of the Hessian's trace w.r.t the node's activations)
fw_dataloader = self.fw_impl.convert_data_gen_to_dataloader(self.representative_data_gen,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
from model_compression_toolkit.core.common.quantization.node_quantization_config import ActivationQuantizationMode
from model_compression_toolkit.core.common.model_builder_mode import ModelBuilderMode
from model_compression_toolkit.core.common.hessian import HessianInfoService
from model_compression_toolkit.core.common.progress_config.progress_info_controller import \
ProgressInfoController


class SensitivityEvaluation:
Expand All @@ -41,7 +43,8 @@ def __init__(self,
fw_info: FrameworkInfo,
fw_impl: Any,
disable_activation_for_metric: bool = False,
hessian_info_service: HessianInfoService = None
hessian_info_service: HessianInfoService = None,
progress_info_controller: ProgressInfoController = None
):
"""
Args:
Expand All @@ -53,7 +56,7 @@ def __init__(self,
fw_impl: FrameworkImplementation object with a specific framework methods implementation.
disable_activation_for_metric: Whether to disable activation quantization when computing the MP metric.
hessian_info_service: HessianInfoService to fetch Hessian approximation information.

progress_info_controller: ProgressInfoController to display and manage overall progress information.
"""
self.mp_config = mp_config
self.representative_data_gen = representative_data_gen
Expand All @@ -65,7 +68,8 @@ def __init__(self,
else:
self.metric_calculator = DistanceMetricCalculator(graph, mp_config, representative_data_gen,
fw_info=fw_info, fw_impl=fw_impl,
hessian_info_service=hessian_info_service)
hessian_info_service=hessian_info_service,
progress_info_controller=progress_info_controller)

# Build a mixed-precision model which can be configured to use different bitwidth in different layers.
# Also, returns a mapping between a configurable graph's node and its matching layer(s) in the built MP model.
Expand Down
14 changes: 14 additions & 0 deletions model_compression_toolkit/core/common/progress_config/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright 2026 Sony Semiconductor Solutions, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
24 changes: 24 additions & 0 deletions model_compression_toolkit/core/common/progress_config/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright 2026 Sony Semiconductor Solutions, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# Keys of the progress-information dictionary that is passed to the
# user-supplied progress callback on every completed step.
COMPLETED_COMPONENTS = 'completedComponents'
TOTAL_COMPONENTS = 'totalComponents'
CURRENT_COMPONENT = 'currentComponent'

# Keyword-argument names inspected when constructing the progress controller
# (they mirror the controller's dataclass field names).
PROGRESS_INFO_CALLBACK = 'progress_info_callback'
TOTAL_STEP = 'total_step'

# tqdm `position` (bar row offset) used for the overall progress bar.
PROGRESS_BAR_POSITION = 2
# Base number of always-executed pipeline steps
# (preprocessing, statistics, weight params, post-processing).
DEFAULT_TOTAL_STEP = 4
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@

# Copyright 2026 Sony Semiconductor Solutions, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from typing import Optional, Callable, TYPE_CHECKING
from dataclasses import dataclass, field
from tqdm import tqdm

from model_compression_toolkit.core.common.progress_config.constants import (
COMPLETED_COMPONENTS, TOTAL_COMPONENTS, CURRENT_COMPONENT,
PROGRESS_BAR_POSITION, PROGRESS_INFO_CALLBACK, TOTAL_STEP, DEFAULT_TOTAL_STEP
)

if TYPE_CHECKING: # pragma: no cover
from model_compression_toolkit.core import CoreConfig
from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization


@dataclass
class ProgressInfoController:
    """
    A unified progress bar controller class.
    Supports a single progress bar.

    Note: instantiation returns None (via __new__) when progress display is
    disabled, so callers can guard usage with a simple `is not None` check.

    Attributes:
        total_step: Total number of processing steps.
        current_step: Current step number (starts from 0, incremented by set_description()).
        description: Description for the progress bar.
        progress_info_callback: User-defined callback, invoked with a
            progress-information dict on every completed step.
    """
    total_step: int = field(default=0)
    current_step: int = field(default=0)
    description: str = field(default="Model Compression Toolkit Progress Information")
    progress_info_callback: Optional[Callable] = field(default=None)

    def __new__(cls, *args, **kwargs):
        """
        Create the instance, or skip instantiation based on the enable flag.

        Returns None (disabling progress display) when no callback is supplied
        or when total_step is not a positive number.

        Raises:
            TypeError: If progress_info_callback is supplied but not callable.
        """
        # Map positional arguments onto the dataclass field order so the
        # enable-check below also sees them. A kwargs-only lookup would
        # silently disable progress for positional callers.
        bound = dict(zip((TOTAL_STEP, 'current_step', 'description', PROGRESS_INFO_CALLBACK), args))
        bound.update(kwargs)

        progress_info_callback = bound.get(PROGRESS_INFO_CALLBACK)
        # Default a missing total_step to 0 so it disables progress instead of
        # raising `TypeError: '<=' not supported between NoneType and int`.
        total_step = bound.get(TOTAL_STEP) or 0

        if progress_info_callback is None or total_step <= 0:
            return None

        if not callable(progress_info_callback):
            raise TypeError(f"{PROGRESS_INFO_CALLBACK} must be a callable (function or callable instance).")

        return super().__new__(cls)

    def __post_init__(self):
        """Create the underlying tqdm progress bar after dataclass initialization."""
        # Initial single bar mode
        self.pbar = tqdm(
            total=self.total_step,
            desc=self.description,
            position=PROGRESS_BAR_POSITION,
            leave=False,
            unit='step',
            dynamic_ncols=True,
            bar_format='{l_bar}{bar:}|'
        )

    def set_description(self, description: str):
        """
        Update progress bar description.
        Automatically increments step number each time set_description is called,
        displaying in "Step X/Y: ..." format, and notifies the user callback.

        Args:
            description: New description text ("Step X/Y: " is automatically added).

        Raises:
            AssertionError: If more steps are reported than total_step allows.
        """
        self.description = description
        self.current_step += 1
        formatted_description = f"Step {self.current_step}/{self.total_step}: {description}"

        # Explicit raise instead of an `assert` statement so the overflow check
        # survives `python -O`; the bar is closed first to avoid leaking it.
        if self.current_step > self.total_step:
            self.close()
            raise AssertionError(
                f"current_step: {self.current_step}, exceeded total_step: {self.total_step}.")

        self.pbar.set_description(formatted_description, refresh=False)
        self.pbar.update()

        # Report progress to the user-defined callback.
        progress_info = {
            COMPLETED_COMPONENTS: description,
            TOTAL_COMPONENTS: self.total_step,
            CURRENT_COMPONENT: self.current_step
        }
        self.progress_info_callback(progress_info)

    def close(self):
        """Close progress bar (idempotent: safe to call more than once)."""
        if self.pbar is not None:
            self.pbar.close()
            self.pbar = None


def research_progress_total(core_config: 'CoreConfig',
                            target_resource_utilization: 'ResourceUtilization' = None,
                            gptq_config: 'GradientPTQConfig' = None) -> int:
    """
    Check whether specific processing will be executed based on input arguments
    and calculate the total number of processing steps.

    Processing step breakdown:
    1. Preprocessing (required)
    2. Statistics calculation (required)
    3. Weight parameter calculation (required)
    4. Hessian calculation (when GPTQ or specific settings enabled)
    5. MP calculation (when Mixed Precision enabled)
    6. Post-processing ~ conversion to exportable model (required)

    Args:
        core_config: CoreConfig object.
        target_resource_utilization: ResourceUtilization object (used for Mixed Precision determination).
        gptq_config: GPTQ configuration object (may be None when GPTQ is disabled).

    Returns:
        Total number of processing steps.
    """
    # Base required steps: preprocessing, statistics, weight params, post-processing
    total_steps = DEFAULT_TOTAL_STEP

    # Add MP calculation step (when Mixed Precision enabled)
    if target_resource_utilization is not None and \
            target_resource_utilization.is_any_restricted():
        total_steps += 1

    # Add Hessian step (when Mixed Precision with Hessian enabled)
    if core_config.mixed_precision_config is not None and \
            core_config.mixed_precision_config.use_hessian_based_scores:
        total_steps += 1

    if gptq_config is not None:
        # Add GPTQ training step (when GPTQ is enabled)
        total_steps += 1

        # Add Hessian step (when GPTQ with Hessian enabled).
        # Guarded by the enclosing `gptq_config is not None` check: the previous
        # unconditional attribute access raised AttributeError whenever
        # gptq_config was None (its default).
        if gptq_config.hessian_weights_config is not None:
            total_steps += 1

    return total_steps
Loading