Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 48 additions & 1 deletion neural_compressor/common/base_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Sized, Tuple, Union

from neural_compressor.common.base_config import BaseConfig
from neural_compressor.common.utils import TuningLogger, logger
from neural_compressor.common.utils import Statistics, TuningLogger, logger

__all__ = [
"Evaluator",
Expand Down Expand Up @@ -423,6 +423,47 @@ def add_trial_result(self, trial_index: int, trial_result: Union[int, float], qu
trial_record = _TrialRecord(trial_index, trial_result, quant_config)
self.tuning_history.append(trial_record)

# Print tuning results table
self._print_trial_results_table(trial_index, trial_result)

def _print_trial_results_table(self, trial_index: int, trial_result: Union[int, float]) -> None:
"""Print trial results in a formatted table using Statistics class."""
baseline_val = self.baseline if self.baseline is not None else 0.0
baseline_str = f"{baseline_val:.4f}" if self.baseline is not None else "N/A"
target_threshold_str = (
f"{baseline_val * (1 - self.tuning_config.tolerable_loss):.4f}" if self.baseline is not None else "N/A"
)

# Calculate relative loss if baseline is available
relative_loss_val = 0.0
relative_loss_str = "N/A"
if self.baseline is not None:
relative_loss_val = (baseline_val - trial_result) / baseline_val
relative_loss_str = f"{relative_loss_val*100:.2f}%"

# Get best result so far
best_result = max(record.trial_result for record in self.tuning_history)

# Status indicator with emoji
if self.baseline is not None and trial_result >= (baseline_val * (1 - self.tuning_config.tolerable_loss)):
status = "✅ PASSED"
else:
status = "❌ FAILED"

# Prepare data for Statistics table with combined fields
field_names = ["📊 Metric", "📈 Value"]
output_data = [
["Trial / Progress", f"{len(self.tuning_history)}/{self.tuning_config.max_trials}"],
["Baseline / Target", f"{baseline_str} / {target_threshold_str}"],
["Current / Status", f"{trial_result:.4f} | {status}"],
["Best / Relative Loss", f"{best_result:.4f} / {relative_loss_str}"],
]

# Use Statistics class to print the table
Statistics(
output_data, header=f"🎯 Auto-Tune Trial #{trial_index} Results", field_names=field_names
).print_stat()

def set_baseline(self, baseline: float):
"""Set the baseline value for auto-tune.

Expand Down Expand Up @@ -488,4 +529,10 @@ def init_tuning(tuning_config: TuningConfig) -> Tuple[ConfigLoader, TuningLogger
config_loader = ConfigLoader(config_set=tuning_config.config_set, sampler=tuning_config.sampler)
tuning_logger = TuningLogger()
tuning_monitor = TuningMonitor(tuning_config)

# Update max_trials based on actual number of available configurations
actual_config_count = len(config_loader.config_set)
if tuning_config.max_trials > actual_config_count:
tuning_config.max_trials = actual_config_count

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please raise a message to let the user be aware of that change, others LGTM. Thanks!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

max_trails=100 is the default, which means this message will happen in most cases. To avoid distracting users, I recommend not printing information.

return config_loader, tuning_logger, tuning_monitor
Loading