diff --git a/optimum/commands/neuron/cache.py b/optimum/commands/neuron/cache.py index cb437c347..e3e90c0b2 100644 --- a/optimum/commands/neuron/cache.py +++ b/optimum/commands/neuron/cache.py @@ -147,7 +147,7 @@ def _list_entries(self): str(entry["batch_size"]), str(entry["sequence_length"]), str(entry.get("tp_degree", entry.get("tensor_parallel_size"))), - str(entry["torch_dtype"]), + str(entry.get("torch_dtype", entry.get("dtype"))), str(entry["target"]), ) ) diff --git a/optimum/neuron/trainers/metrics/__init__.py b/optimum/neuron/trainers/metrics/__init__.py new file mode 100644 index 000000000..d7723edfa --- /dev/null +++ b/optimum/neuron/trainers/metrics/__init__.py @@ -0,0 +1,17 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .collector import TrainingMetricsCollector +from .window import MovingAverageWindow diff --git a/optimum/neuron/trainers/metrics/base.py b/optimum/neuron/trainers/metrics/base.py new file mode 100644 index 000000000..994eb4c01 --- /dev/null +++ b/optimum/neuron/trainers/metrics/base.py @@ -0,0 +1,74 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from ..training_args import NeuronTrainingArguments + + +if TYPE_CHECKING: + from .collector import TrainingMetricsCollector + + +class MetricUnit: + SECONDS = "s" + MILLISECONDS = "ms" + TOKENS_PER_SECOND = "tokens/s" + SAMPLES_PER_SECOND = "samples/s" + PERCENT = "%" + COUNT = "count" + TFLOPS = "TFLOP/s" + RATIO = "ratio" + NONE = "" + + +@dataclass +class MetricPlugin(ABC): + """Base class for metrics plugins. 
Each plugin calculates one type of metric.""" + + name: str + requires_accumulation: bool = False + depends_on: list[str] | None = None + + @abstractmethod + def is_enabled(self, args: NeuronTrainingArguments) -> bool: + """Check if this plugin should be active.""" + pass + + @abstractmethod + def calculate_realtime(self, window_stats: dict, collector: "TrainingMetricsCollector") -> dict[str, float]: + """Calculate train/ metrics from current window data.""" + pass + + @abstractmethod + def calculate_summary(self, summary_data: dict, collector: "TrainingMetricsCollector") -> dict[str, float]: + """Calculate summary/ metrics from all collected data.""" + pass + + def get_metric_names(self) -> list[str]: + """Get the metrics this plugin provides. Override for multi-metric plugins.""" + return [self.name] + + def handles_metric(self, metric_name: str) -> bool: + """Check if this plugin handles the given metric.""" + return metric_name in self.get_metric_names() + + def get_metric_units(self) -> dict[str, str]: + """Get units for each metric this plugin produces. Override in subclasses.""" + return dict.fromkeys(self.get_metric_names(), MetricUnit.NONE) diff --git a/optimum/neuron/trainers/metrics/collector.py b/optimum/neuron/trainers/metrics/collector.py new file mode 100644 index 000000000..a6d9db3e8 --- /dev/null +++ b/optimum/neuron/trainers/metrics/collector.py @@ -0,0 +1,406 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +from contextlib import contextmanager +from typing import Any + +import torch +import torch_xla.runtime as xr +from neuronx_distributed.utils.model_utils import LogicalNCConfig, get_platform_lnc +from torch_neuronx.utils import get_platform_target + +from ...models.training.training_utils import get_model_param_count +from ..training_args import NeuronTrainingArguments +from .base import MetricPlugin +from .constants import HARDWARE_TFLOPS +from .efficiency import EfficiencyPlugin +from .mfu import MFUPlugin +from .registry import PluginRegistry +from .throughput import ThroughputPlugin +from .timing import ComponentTimingPlugin +from .window import MovingAverageWindow + + +class TrainingMetricsCollector: + """ + Tracks training performance metrics using a plugin system. + + Provides real-time metrics during training and summary stats at the end. + Auto-detects Trainium hardware and handles distributed training setups. 
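+
+    Illustrative usage (names such as `model`, `training_args` and `inputs` are placeholders):
+
+        collector = TrainingMetricsCollector(model, training_args)
+        with collector.time_metric("forward_pass", inputs=inputs):
+            outputs = model(**inputs)
+        realtime_metrics = collector.calculate_metric("all")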
+ """ + + def __init__( + self, + model: Any, + training_args: NeuronTrainingArguments, + custom_plugins: list[MetricPlugin] | None = None, + ): + self.model = model + self.args = training_args + + # Set up plugins + all_plugins = self._get_default_plugins() + (custom_plugins or []) + self.registry = PluginRegistry(all_plugins) + self.registry.validate_dependencies() + + # Check if any metrics are enabled + self.enabled = any(plugin.is_enabled(training_args) for plugin in all_plugins) + + if not self.enabled: + self.metric_windows = {} + self.metric_start_times = {} + self.current_batch_data = {} + self.summary_metrics = {} + return + + self._validate_inputs() + + self.dp_size = self.args.trn_config.data_parallel_size + self.total_neuron_cores = xr.world_size() + self.platform_lnc = get_platform_lnc() + + self.model_params = None + self.num_layers = model.config.num_hidden_layers + self.num_heads = model.config.num_attention_heads + self.hidden_size = model.config.hidden_size + self.seq_length = None + self.head_dim = getattr(model.config, "head_dim", self.hidden_size // self.num_heads) + + if self.args.enable_mfu_metrics: + self.model_params = get_model_param_count(model, trainable_only=False) + + self.peak_tflops_per_core = self._detect_hardware_tflops() + self.window_size = self.args.metrics_window_size + + # Only work with enabled plugins + self.active_plugins = [p for p in all_plugins if p.is_enabled(training_args)] + + # Initialize metric windows and tracking for active plugins + self.metric_windows = {} + self.metric_start_times = {} + self.current_batch_data = {} + self.summary_metrics = {} + self.accumulating_metrics = set() + + for plugin in self.active_plugins: + for metric_name in plugin.get_metric_names(): + self.metric_windows[metric_name] = MovingAverageWindow(self.window_size) + self.metric_start_times[metric_name] = None + self.current_batch_data[metric_name] = {"tokens": 0, "samples": 0} + self.summary_metrics[metric_name] = { + "step_times": [], + "tokens_per_step": [], + "samples_per_step": [], + "step_numbers": [], + } + if plugin.requires_accumulation: + self.accumulating_metrics.add(metric_name) + + self.cycle_active = False + self.cycle_accumulators = dict.fromkeys(self.accumulating_metrics, 0.0) + self.cycle_batch_data = {"tokens": 0, "samples": 0} + self.component_start_times = dict.fromkeys(self.accumulating_metrics, None) + self.component_start_times = dict.fromkeys(self.accumulating_metrics, None) + + def _get_default_plugins(self) -> list[MetricPlugin]: + return [ + ThroughputPlugin(), + MFUPlugin(), + EfficiencyPlugin(), + ComponentTimingPlugin(), + ] + + def _validate_inputs(self): + if self.args.metrics_window_size <= 0: + raise ValueError(f"metrics_window_size must be > 0, got {self.args.metrics_window_size}") + + if self.args.enable_mfu_metrics and self.model is None: + raise ValueError("Model cannot be None when MFU metrics are enabled") + + def _detect_hardware_tflops(self) -> float: + platform_target = get_platform_target().lower() + if platform_target not in HARDWARE_TFLOPS: + raise ValueError(f"Unknown platform '{platform_target}'. We support: {list(HARDWARE_TFLOPS.keys())}") + + # Detect training precision + dtype = self._detect_training_precision() + platform_specs = HARDWARE_TFLOPS[platform_target] + + # Adjust for LNC2 if applicable + # When using LNC2, one logical core associated with one process, maps to two physical cores. + # Therefore, the peak TFLOPS per logical core is doubled. 
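+        # Example with the values in constants.py: trn2 bf16 is 667 / 8 ≈ 83.4 TFLOPS per
+        # physical core, so under LNC2 each logical core is credited with ≈ 166.8 TFLOPS.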
+ if self.platform_lnc is LogicalNCConfig.LNC_2: + platform_specs = {k: v * 2 for k, v in platform_specs.items()} + + if dtype not in platform_specs: + raise ValueError( + f"Unknown precision '{dtype}' for platform '{platform_target}'. " + f"Supported precisions: {list(platform_specs.keys())}" + ) + + return platform_specs[dtype] + + def _detect_training_precision(self) -> str: + if self.args.bf16 or self.args.use_autocast: + return "bf16" + + # Check model dtype if available + if self.model is not None: + try: + # Get the first parameter's dtype + first_param = next(self.model.parameters()) + if first_param.dtype == torch.bfloat16: + return "bf16" + elif first_param.dtype == torch.float32: + return "fp32" + except (StopIteration, AttributeError): + pass + + # Default to fp32 if we can't determine + return "fp32" + + def _should_calculate_plugin(self, plugin: MetricPlugin, metric_type: str) -> bool: + if metric_type == "all": + return True + if plugin.name == metric_type: + return True + if hasattr(plugin, "handles_metric") and plugin.handles_metric(metric_type): + return True + return False + + def _get_plugin_window_stats(self, plugin: MetricPlugin) -> dict: + if hasattr(plugin, "get_metric_names") and len(plugin.get_metric_names()) > 1: + # Multi-metric plugins use inter-plugin communication instead + return {} + else: + # Single-metric plugins get their own window stats + metric_name = plugin.name + if metric_name in self.metric_windows: + return self.metric_windows[metric_name].get_window_stats() + return {} + + def get_metric_average_time(self, metric_name: str) -> float: + if metric_name not in self.metric_windows: + return 0.0 + window_stats = self.metric_windows[metric_name].get_window_stats() + return window_stats.get("avg_time_per_step", 0.0) + + def get_metric_window_stats(self, metric_name: str) -> dict: + if metric_name not in self.metric_windows: + return {} + return self.metric_windows[metric_name].get_window_stats() + + def get_metric_unit(self, metric_name: str) -> str: + """Get the unit for a specific metric.""" + for plugin in self.active_plugins: + if plugin.handles_metric(metric_name): + units = plugin.get_metric_units() + return units.get(metric_name, "") + return "" + + def get_all_metric_units(self) -> dict[str, str]: + """Get units for all metrics from all active plugins.""" + all_units = {} + for plugin in self.active_plugins: + all_units.update(plugin.get_metric_units()) + return all_units + + def start_gradient_accumulation_cycle(self): + """Start accumulating timing across multiple forward/backward passes.""" + if not self.enabled: + return + self.cycle_active = True + self.cycle_accumulators = dict.fromkeys(self.accumulating_metrics, 0.0) + self.cycle_batch_data = {"tokens": 0, "samples": 0} + self.component_start_times = dict.fromkeys(self.accumulating_metrics, None) + + def end_gradient_accumulation_cycle(self, step_number: int | None = None): + """Finish accumulation cycle and record the total times.""" + if not self.enabled or not self.cycle_active: + return + + for metric_name in self.accumulating_metrics: + if metric_name in self.cycle_accumulators: + self.metric_windows[metric_name].add_step( + tokens=self.cycle_batch_data["tokens"], + samples=self.cycle_batch_data["samples"], + step_time=self.cycle_accumulators[metric_name], + ) + + self.summary_metrics[metric_name]["step_times"].append(self.cycle_accumulators[metric_name]) + self.summary_metrics[metric_name]["tokens_per_step"].append(self.cycle_batch_data["tokens"]) + 
self.summary_metrics[metric_name]["samples_per_step"].append(self.cycle_batch_data["samples"]) + self.summary_metrics[metric_name]["step_numbers"].append(step_number or 0) + + self.cycle_active = False + self.cycle_accumulators = dict.fromkeys(self.accumulating_metrics, 0.0) + self.cycle_batch_data = {"tokens": 0, "samples": 0} + self.component_start_times = dict.fromkeys(self.accumulating_metrics, None) + + def start_metric(self, metric_name: str, inputs: dict[str, Any] | None = None): + """Start timing a metric.""" + if not self.enabled: + return + if metric_name not in self.metric_start_times: + raise ValueError(f"Unknown metric: {metric_name}. Available: {list(self.metric_start_times.keys())}") + + if self.cycle_active and metric_name in self.accumulating_metrics: + self.component_start_times[metric_name] = time.perf_counter() + if inputs is not None: + self._update_cycle_batch_data(inputs) + else: + self.metric_start_times[metric_name] = time.perf_counter() + self.current_batch_data[metric_name] = {"tokens": 0, "samples": 0} + if inputs is not None: + self._update_batch_data(metric_name, inputs) + + def update_metric_batch_data(self, metric_name: str, inputs: dict[str, Any]): + if not self.enabled or metric_name not in self.current_batch_data: + return + self._update_batch_data(metric_name, inputs) + + @contextmanager + def time_metric(self, metric_name: str, inputs: dict[str, Any] | None = None, step_number: int | None = None): + """Context manager for timing - handles start/stop automatically.""" + if not self.enabled: + yield + return + + self.start_metric(metric_name, inputs) + try: + yield + finally: + self.stop_metric(metric_name, step_number) + + def _update_batch_data(self, metric_name: str, inputs: dict[str, Any]): + batch_tokens = 0 + batch_samples = 0 + + if "input_ids" in inputs: + input_ids = inputs["input_ids"] + if isinstance(input_ids, torch.Tensor): + batch_tokens = input_ids.numel() + batch_samples = input_ids.size(0) + + if self.seq_length is None: + self.seq_length = input_ids.size(1) + + self.current_batch_data[metric_name]["tokens"] += batch_tokens + self.current_batch_data[metric_name]["samples"] += batch_samples + + def _update_cycle_batch_data(self, inputs: dict[str, Any]): + batch_tokens = 0 + batch_samples = 0 + + if "input_ids" in inputs: + input_ids = inputs["input_ids"] + if isinstance(input_ids, torch.Tensor): + batch_tokens = input_ids.numel() + batch_samples = input_ids.size(0) + + if self.seq_length is None: + self.seq_length = input_ids.size(1) + + self.cycle_batch_data["tokens"] += batch_tokens + self.cycle_batch_data["samples"] += batch_samples + + def stop_metric(self, metric_name: str, step_number: int | None = None): + """Stop timing and record the measurement.""" + if not self.enabled: + return + if metric_name not in self.metric_start_times: + raise ValueError(f"Unknown metric: {metric_name}. 
Available: {list(self.metric_start_times.keys())}") + + if self.cycle_active and metric_name in self.accumulating_metrics: + if self.component_start_times[metric_name] is None: + return + + elapsed_time = time.perf_counter() - self.component_start_times[metric_name] + self.cycle_accumulators[metric_name] += elapsed_time + self.component_start_times[metric_name] = None + else: + if self.metric_start_times[metric_name] is None: + return + + elapsed_time = time.perf_counter() - self.metric_start_times[metric_name] + batch_data = self.current_batch_data[metric_name] + + self.metric_windows[metric_name].add_step( + tokens=batch_data["tokens"], + samples=batch_data["samples"], + step_time=elapsed_time, + ) + + self.summary_metrics[metric_name]["step_times"].append(elapsed_time) + self.summary_metrics[metric_name]["tokens_per_step"].append(batch_data["tokens"]) + self.summary_metrics[metric_name]["samples_per_step"].append(batch_data["samples"]) + self.summary_metrics[metric_name]["step_numbers"].append(step_number or 0) + + self.metric_start_times[metric_name] = None + self.current_batch_data[metric_name] = {"tokens": 0, "samples": 0} + + def calculate_metric(self, metric_name: str) -> dict[str, float]: + """Calculate specific metric(s). Use `"all"` to get everything.""" + if not self.enabled: + return {} + + results = {} + for plugin in self.registry.get_plugins_in_dependency_order(): + if self._should_calculate_plugin(plugin, metric_name): + window_stats = self._get_plugin_window_stats(plugin) + results.update(plugin.calculate_realtime(window_stats, self)) + + return results + + def calculate_summary_metrics(self) -> dict[str, float]: + if not self.enabled: + return {} + + summary = {} + + for plugin in self.active_plugins: + # For multi-metric plugins, we need to get the right summary data + if hasattr(plugin, "get_metric_names") and len(plugin.get_metric_names()) > 1: + # For efficiency plugin, pass empty dict as it accesses summary_metrics directly + summary_data = {} + else: + # For single-metric plugins, pass their specific summary data + metric_name = plugin.name + summary_data = self.summary_metrics.get(metric_name, {}) + + summary.update(plugin.calculate_summary(summary_data, self)) + + return summary + + def reset_window(self): + if not self.enabled: + return + for metric_name in self.metric_windows: + self.metric_windows[metric_name].clear() + self.metric_start_times[metric_name] = None + self.current_batch_data[metric_name] = {"tokens": 0, "samples": 0} + + def reset_all_metrics(self): + if not self.enabled: + return + self.reset_window() + for metric_name in self.summary_metrics: + self.summary_metrics[metric_name] = { + "step_times": [], + "tokens_per_step": [], + "samples_per_step": [], + "step_numbers": [], + } diff --git a/optimum/neuron/trainers/metrics/constants.py b/optimum/neuron/trainers/metrics/constants.py new file mode 100644 index 000000000..a82aee5c8 --- /dev/null +++ b/optimum/neuron/trainers/metrics/constants.py @@ -0,0 +1,42 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + + +class MetricNames: + """Names for all the metrics we track during training.""" + + THROUGHPUT = "throughput" + MFU = "mfu" + EFFICIENCY = "efficiency" + + # Component timing metrics + FORWARD_PASS = "forward_pass" + BACKWARD_PASS = "backward_pass" + OPTIMIZER_STEP = "optimizer_step" + TOTAL_STEP = "total_step" + + +# Specs for Trainium 1 and 2 can be found here: +# https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/arch/neuron-hardware/trainium2.html#compute +HARDWARE_TFLOPS = { + "trn1": { + "fp32": 48 / 2, + "bf16": 191 / 2, + }, + "trn2": { + "fp32": 181 / 8, + "bf16": 667 / 8, + }, +} diff --git a/optimum/neuron/trainers/metrics/efficiency.py b/optimum/neuron/trainers/metrics/efficiency.py new file mode 100644 index 000000000..b6cecfc8d --- /dev/null +++ b/optimum/neuron/trainers/metrics/efficiency.py @@ -0,0 +1,143 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from ..training_args import NeuronTrainingArguments +from .base import MetricPlugin, MetricUnit +from .constants import MetricNames + + +if TYPE_CHECKING: + from .collector import TrainingMetricsCollector + + +class EfficiencyPlugin(MetricPlugin): + """Calculates how much time is spent on useful computation vs overhead.""" + + def __init__(self): + super().__init__( + name=MetricNames.EFFICIENCY, + requires_accumulation=False, + depends_on=[ + MetricNames.FORWARD_PASS, + MetricNames.BACKWARD_PASS, + MetricNames.OPTIMIZER_STEP, + MetricNames.TOTAL_STEP, + ], + ) + + def is_enabled(self, args: NeuronTrainingArguments) -> bool: + return args.enable_efficiency_metrics + + def calculate_realtime(self, window_stats: dict, collector: "TrainingMetricsCollector") -> dict[str, float]: + """Efficiency = compute time / total time. Shows how much time we spend doing useful work.""" + # Get timing data from other plugins + forward_time = collector.get_metric_average_time(MetricNames.FORWARD_PASS) + backward_time = collector.get_metric_average_time(MetricNames.BACKWARD_PASS) + optimizer_time = collector.get_metric_average_time(MetricNames.OPTIMIZER_STEP) + total_time = collector.get_metric_average_time(MetricNames.TOTAL_STEP) + + if total_time <= 0: + return {} + + compute_time = forward_time + backward_time + optimizer_time + efficiency_pct = (compute_time / total_time) * 100 + + # Break down by component + forward_pct = (forward_time / total_time) * 100 + backward_pct = (backward_time / total_time) * 100 + optimizer_pct = (optimizer_time / total_time) * 100 + overhead_pct = 100 - efficiency_pct # Communication, data loading, etc. 
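+        # Illustrative breakdown: forward 45% + backward 40% + optimizer 10% of the step time
+        # gives an efficiency of 95%, leaving 5% overhead (communication, data loading, ...).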
+ + return { + "train/efficiency": round(efficiency_pct, 2), + "train/forward_time_percent": round(forward_pct, 2), + "train/backward_time_percent": round(backward_pct, 2), + "train/optimizer_time_percent": round(optimizer_pct, 2), + "train/overhead_time_percent": round(overhead_pct, 2), + } + + def calculate_summary(self, summary_data: dict, collector: "TrainingMetricsCollector") -> dict[str, float]: + """Calculate average efficiency over the entire training run.""" + # Get all the step times from component metrics + forward_data = collector.summary_metrics.get(MetricNames.FORWARD_PASS, {}) + backward_data = collector.summary_metrics.get(MetricNames.BACKWARD_PASS, {}) + optimizer_data = collector.summary_metrics.get(MetricNames.OPTIMIZER_STEP, {}) + total_data = collector.summary_metrics.get(MetricNames.TOTAL_STEP, {}) + + forward_times = forward_data.get("step_times", []) + backward_times = backward_data.get("step_times", []) + optimizer_times = optimizer_data.get("step_times", []) + total_times = total_data.get("step_times", []) + + if not all([forward_times, backward_times, optimizer_times, total_times]): + return {} + + min_steps = min(len(forward_times), len(backward_times), len(optimizer_times), len(total_times)) + + efficiency_values = [] + forward_pcts = [] + backward_pcts = [] + optimizer_pcts = [] + overhead_pcts = [] + + for i in range(min_steps): + total_time = total_times[i] + if total_time <= 0: + continue + + forward_time = forward_times[i] + backward_time = backward_times[i] + optimizer_time = optimizer_times[i] + + compute_time = forward_time + backward_time + optimizer_time + efficiency = (compute_time / total_time) * 100 + + efficiency_values.append(efficiency) + forward_pcts.append((forward_time / total_time) * 100) + backward_pcts.append((backward_time / total_time) * 100) + optimizer_pcts.append((optimizer_time / total_time) * 100) + overhead_pcts.append(100 - efficiency) + + if not efficiency_values: + return {} + + def avg(values): + return round(sum(values) / len(values), 2) + + return { + "summary/efficiency_avg": avg(efficiency_values), + "summary/forward_time_percent_avg": avg(forward_pcts), + "summary/backward_time_percent_avg": avg(backward_pcts), + "summary/optimizer_time_percent_avg": avg(optimizer_pcts), + "summary/overhead_time_percent_avg": avg(overhead_pcts), + } + + def get_metric_units(self) -> dict[str, str]: + return { + "train/efficiency": MetricUnit.PERCENT, + "train/forward_time_percent": MetricUnit.PERCENT, + "train/backward_time_percent": MetricUnit.PERCENT, + "train/optimizer_time_percent": MetricUnit.PERCENT, + "train/overhead_time_percent": MetricUnit.PERCENT, + "summary/efficiency_avg": MetricUnit.PERCENT, + "summary/forward_time_percent_avg": MetricUnit.PERCENT, + "summary/backward_time_percent_avg": MetricUnit.PERCENT, + "summary/optimizer_time_percent_avg": MetricUnit.PERCENT, + "summary/overhead_time_percent_avg": MetricUnit.PERCENT, + } diff --git a/optimum/neuron/trainers/metrics/mfu.py b/optimum/neuron/trainers/metrics/mfu.py new file mode 100644 index 000000000..b4bbc90ab --- /dev/null +++ b/optimum/neuron/trainers/metrics/mfu.py @@ -0,0 +1,98 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from ..training_args import NeuronTrainingArguments +from .base import MetricPlugin, MetricUnit +from .constants import MetricNames + + +if TYPE_CHECKING: + from .collector import TrainingMetricsCollector + + +class MFUPlugin(MetricPlugin): + """Calculates Model FLOPS Utilization - how efficiently we're using the hardware.""" + + def __init__(self): + super().__init__(name=MetricNames.MFU, requires_accumulation=False) + + def is_enabled(self, args: NeuronTrainingArguments) -> bool: + return args.enable_mfu_metrics + + def _compute_mfu(self, tokens: int, time: float, collector: "TrainingMetricsCollector") -> float: + """ + Compute the system-wide MFU percentage. + + Refer to the PaLM paper Appendix B (page 66) for the MFU formula: + https://arxiv.org/pdf/2204.02311 + """ + if collector.seq_length is None: + raise ValueError("Sequence length must be set in the collector to calculate MFU.") + + N = collector.model_params + L, H, Q, T = collector.num_layers, collector.num_heads, collector.head_dim, collector.seq_length + flops_per_token = 6 * N + 12 * L * H * Q * T + + # System-wide MFU calculation + system_tokens = tokens * collector.dp_size # Scale by data parallel size + system_flops_per_iter = flops_per_token * system_tokens + system_actual_flops_per_sec = system_flops_per_iter / time + system_peak_flops_per_sec = collector.peak_tflops_per_core * collector.total_neuron_cores * 1e12 + system_mfu_pct = (system_actual_flops_per_sec / system_peak_flops_per_sec) * 100 + return system_mfu_pct + + def calculate_realtime(self, window_stats: dict, collector: "TrainingMetricsCollector") -> dict[str, float]: + """MFU = actual FLOPS / peak FLOPS as a percentage.""" + if ( + not window_stats + or collector.model_params is None + or window_stats.get("total_time", 0) <= 0 + or window_stats.get("total_tokens", 0) == 0 + ): + return {} + + total_tokens = window_stats["total_tokens"] # Per-core tokens + total_time = window_stats["total_time"] + system_mfu_pct = self._compute_mfu(total_tokens, total_time, collector) + return {"train/mfu": round(system_mfu_pct, 2)} + + def calculate_summary(self, summary_data: dict, collector: "TrainingMetricsCollector") -> dict[str, float]: + """Average MFU over the entire training run.""" + step_times = summary_data.get("step_times", []) + tokens_per_step = summary_data.get("tokens_per_step", []) + + if not step_times or collector.model_params is None: + return {} + + mfu_values = [] + for tokens, time in zip(tokens_per_step, step_times): + if time > 0 and tokens > 0: + system_mfu_pct = self._compute_mfu(tokens, time, collector) + mfu_values.append(system_mfu_pct) + + if mfu_values: + return {"summary/mfu_avg": sum(mfu_values) / len(mfu_values)} + + return {} + + def get_metric_units(self) -> dict[str, str]: + return { + "train/mfu": MetricUnit.PERCENT, + "summary/mfu_avg": MetricUnit.PERCENT, + } diff --git a/optimum/neuron/trainers/metrics/registry.py b/optimum/neuron/trainers/metrics/registry.py new file mode 100644 index 000000000..2f688181a --- /dev/null +++ 
b/optimum/neuron/trainers/metrics/registry.py @@ -0,0 +1,53 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .base import MetricPlugin + + +class PluginRegistry: + """Manages plugin discovery and provides explicit inter-plugin communication.""" + + def __init__(self, plugins: list[MetricPlugin]): + self.plugins = {p.name: p for p in plugins} + self.metric_to_plugin = {} + + # Build reverse lookup: metric_name -> plugin + for plugin in plugins: + for metric_name in plugin.get_metric_names(): + self.metric_to_plugin[metric_name] = plugin + + def get_plugin(self, plugin_name: str) -> MetricPlugin | None: + """Get plugin by name.""" + return self.plugins.get(plugin_name) + + def get_plugin_for_metric(self, metric_name: str) -> MetricPlugin | None: + """Get the plugin that handles a specific metric.""" + return self.metric_to_plugin.get(metric_name) + + def validate_dependencies(self) -> None: + """Make sure all plugin dependencies are satisfied.""" + for plugin in self.plugins.values(): + if not plugin.depends_on: + continue + + for dep_metric in plugin.depends_on: + if dep_metric not in self.metric_to_plugin: + raise ValueError(f"Plugin '{plugin.name}' needs metric '{dep_metric}', but no plugin provides it") + + def get_plugins_in_dependency_order(self) -> list[MetricPlugin]: + """Sort plugins so dependencies come first.""" + independent = [p for p in self.plugins.values() if not p.depends_on] + dependent = [p for p in self.plugins.values() if p.depends_on] + return independent + dependent diff --git a/optimum/neuron/trainers/metrics/throughput.py b/optimum/neuron/trainers/metrics/throughput.py new file mode 100644 index 000000000..43b55a0bc --- /dev/null +++ b/optimum/neuron/trainers/metrics/throughput.py @@ -0,0 +1,88 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from ..training_args import NeuronTrainingArguments +from .base import MetricPlugin, MetricUnit +from .constants import MetricNames + + +if TYPE_CHECKING: + from .collector import TrainingMetricsCollector + + +class ThroughputPlugin(MetricPlugin): + """Calculates how many tokens/samples we process per second.""" + + def __init__(self): + super().__init__(name=MetricNames.THROUGHPUT, requires_accumulation=False) + + def is_enabled(self, args: NeuronTrainingArguments) -> bool: + return args.enable_throughput_metrics + + def calculate_realtime(self, window_stats: dict, collector: "TrainingMetricsCollector") -> dict[str, float]: + """Tokens per second across all devices.""" + if not window_stats or window_stats.get("total_time", 0) <= 0: + return {} + + total_tokens = window_stats["total_tokens"] + total_time = window_stats["total_time"] + + metrics = {} + + if total_tokens > 0: + local_tps = total_tokens / total_time + global_tps = local_tps * collector.dp_size # Scale by number of data parallel workers + metrics["train/tokens_per_sec"] = global_tps + + metrics["train/step_time"] = window_stats["avg_time_per_step"] + return metrics + + def calculate_summary(self, summary_data: dict, collector: "TrainingMetricsCollector") -> dict[str, float]: + """Average throughput over the entire training run.""" + step_times = summary_data.get("step_times", []) + tokens_per_step = summary_data.get("tokens_per_step", []) + + if not step_times: + return {} + + # Calculate tokens/sec for each step + local_tps_values = [tokens / time if time > 0 else 0 for tokens, time in zip(tokens_per_step, step_times)] + global_tps_values = [rate * collector.dp_size for rate in local_tps_values] + + summary = {} + if global_tps_values: + summary["summary/tokens_per_sec_avg"] = sum(global_tps_values) / len(global_tps_values) + + summary.update( + { + "summary/total_steps": len(step_times), + "summary/total_tokens_processed": sum(tokens_per_step) * collector.dp_size, + } + ) + + return summary + + def get_metric_units(self) -> dict[str, str]: + return { + "train/tokens_per_sec": MetricUnit.TOKENS_PER_SECOND, + "train/step_time": MetricUnit.SECONDS, + "summary/tokens_per_sec_avg": MetricUnit.TOKENS_PER_SECOND, + "summary/total_steps": MetricUnit.COUNT, + "summary/total_tokens_processed": MetricUnit.COUNT, + } diff --git a/optimum/neuron/trainers/metrics/timing.py b/optimum/neuron/trainers/metrics/timing.py new file mode 100644 index 000000000..3bda98250 --- /dev/null +++ b/optimum/neuron/trainers/metrics/timing.py @@ -0,0 +1,63 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from ..training_args import NeuronTrainingArguments +from .base import MetricPlugin, MetricUnit +from .constants import MetricNames + + +if TYPE_CHECKING: + from .collector import TrainingMetricsCollector + + +class ComponentTimingPlugin(MetricPlugin): + """Tracks individual component times (forward, backward, optimizer, total).""" + + def __init__(self): + super().__init__( + name="component_timing", + requires_accumulation=True, # forward/backward need accumulation across gradient steps + ) + + def is_enabled(self, args: NeuronTrainingArguments) -> bool: + return True # Always needed for efficiency calculations + + def get_metric_names(self) -> list[str]: + return [ + MetricNames.FORWARD_PASS, + MetricNames.BACKWARD_PASS, + MetricNames.OPTIMIZER_STEP, + MetricNames.TOTAL_STEP, + ] + + def calculate_realtime(self, window_stats: dict, collector: "TrainingMetricsCollector") -> dict[str, float]: + """This plugin just provides timing data to other plugins.""" + return {} + + def calculate_summary(self, summary_data: dict, collector: "TrainingMetricsCollector") -> dict[str, float]: + """This plugin just provides timing data to other plugins.""" + return {} + + def get_metric_units(self) -> dict[str, str]: + return { + MetricNames.FORWARD_PASS: MetricUnit.SECONDS, + MetricNames.BACKWARD_PASS: MetricUnit.SECONDS, + MetricNames.OPTIMIZER_STEP: MetricUnit.SECONDS, + MetricNames.TOTAL_STEP: MetricUnit.SECONDS, + } diff --git a/optimum/neuron/trainers/metrics/window.py b/optimum/neuron/trainers/metrics/window.py new file mode 100644 index 000000000..2a7fc89bb --- /dev/null +++ b/optimum/neuron/trainers/metrics/window.py @@ -0,0 +1,71 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import deque + + +class MovingAverageWindow: + """ + A moving average window for tracking metrics over a sliding window. + + Maintains separate deques for tokens, samples, and timing information, + allowing for stable moving average calculations. + + It is storing metrics per-core and is agnostic of any distributed setup. + Scaling of these metrics for distributed training should be handled by `MetricPlugin`. 
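+
+    Illustrative usage (per-core values):
+
+        window = MovingAverageWindow(window_size=3)
+        window.add_step(tokens=128, samples=4, step_time=0.5)
+        stats = window.get_window_stats()  # {"total_tokens": 128, ..., "avg_time_per_step": 0.5}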
+ """ + + def __init__(self, window_size: int): + self.window_size = window_size + self.tokens_per_step = deque(maxlen=window_size) + self.samples_per_step = deque(maxlen=window_size) + self.step_times = deque(maxlen=window_size) + + def add_step(self, tokens: int, samples: int, step_time: float): + self.tokens_per_step.append(tokens) + self.samples_per_step.append(samples) + self.step_times.append(step_time) + + def get_window_stats(self) -> dict[str, float]: + if not self.step_times: + return {} + + total_tokens = sum(self.tokens_per_step) + total_samples = sum(self.samples_per_step) + total_time = sum(self.step_times) + window_steps = len(self.step_times) + + return { + "total_tokens": total_tokens, + "total_samples": total_samples, + "total_time": total_time, + "window_steps": window_steps, + "avg_tokens_per_step": total_tokens / window_steps if window_steps > 0 else 0, + "avg_samples_per_step": total_samples / window_steps if window_steps > 0 else 0, + "avg_time_per_step": total_time / window_steps if window_steps > 0 else 0, + } + + def clear(self): + self.tokens_per_step.clear() + self.samples_per_step.clear() + self.step_times.clear() + + @property + def is_full(self) -> bool: + return len(self.step_times) == self.window_size + + @property + def size(self) -> int: + return len(self.step_times) diff --git a/optimum/neuron/trainers/training_args.py b/optimum/neuron/trainers/training_args.py index e692d38e3..9b26b923e 100644 --- a/optimum/neuron/trainers/training_args.py +++ b/optimum/neuron/trainers/training_args.py @@ -427,6 +427,43 @@ class NeuronTrainingArguments: }, ) + # Training metrics configuration + enable_throughput_metrics: bool = field( + default=True, + metadata={ + "help": ( + "Whether to calculate and log throughput metrics (tokens/sec, samples/sec, both general and per-neuron-core)." + ) + }, + ) + enable_mfu_metrics: bool = field( + default=True, + metadata={ + "help": ( + "Whether to calculate and log Model FLOPs Utilization (MFU) metrics. " + "This requires additional computation and is disabled by default." + ) + }, + ) + enable_efficiency_metrics: bool = field( + default=True, + metadata={ + "help": ( + "Whether to calculate and log training efficiency metrics. " + "This requires additional computation and is disabled by default." + ) + }, + ) + metrics_window_size: int = field( + default=50, + metadata={ + "help": ( + "Size of the moving average window for metrics calculation. " + "Larger windows provide more stable metrics but react slower to changes." + ) + }, + ) + def __post_init__(self): # Set the verbosity so that each process logs according to its rank. log_level = self.get_process_log_level() diff --git a/optimum/neuron/trainers/transformers.py b/optimum/neuron/trainers/transformers.py index 8e9ddcbf8..538fd45a2 100644 --- a/optimum/neuron/trainers/transformers.py +++ b/optimum/neuron/trainers/transformers.py @@ -14,10 +14,12 @@ # limitations under the License. 
import inspect +import json import math import os import re import sys +from datetime import datetime from functools import partial from pathlib import Path from typing import Any, Callable, Iterator, Type @@ -97,6 +99,7 @@ ) from ..utils.import_utils import is_peft_available from ..utils.misc import is_main_worker, is_precompilation +from .metrics import TrainingMetricsCollector from .training_args import NeuronTrainingArguments from .utils import XLAPrefetchIterator @@ -260,6 +263,9 @@ def __init__( ], ) + # Initialize training metrics collector - auto-detects if metrics are enabled + self.metrics_collector = TrainingMetricsCollector(self.model, args) + if isinstance(self.model, NeuronPeftModel) and self.args.label_names is None: logger.warning( f"No label_names provided for model class `{self.model.__class__.__name__}`." @@ -942,8 +948,10 @@ def train_step( manager = self.autocast_smart_context_manager() if isinstance(model, NxDPPModel): - with manager: - loss = model.run_train(**inputs) + # Time forward pass for pipeline parallel models (run_train includes both forward and backward for PP) + with self.metrics_collector.time_metric("forward_pass", inputs=inputs): + with manager: + loss = model.run_train(**inputs) # When using pipeline parallelism, the loss is only computed on the last stage. # So we set the loss to zero on other stages. @@ -954,8 +962,10 @@ def train_step( if num_items_in_batch is not None: inputs = dict(**inputs, reduction="sum") - with manager: - outputs = model(**inputs) + # Time forward pass + with self.metrics_collector.time_metric("forward_pass", inputs=inputs): + with manager: + outputs = model(**inputs) if isinstance(outputs, dict) and "loss" not in outputs: raise ValueError( @@ -970,8 +980,9 @@ def train_step( else: loss = loss / self.args.gradient_accumulation_steps - # Backward pass - self.accelerator.backward(loss) + # Time backward pass + with self.metrics_collector.time_metric("backward_pass", inputs=inputs): + self.accelerator.backward(loss) return loss @@ -998,6 +1009,18 @@ def maybe_log_train_step_metrics(self): reduced_loss = reduced_loss.detach() self.running_loss.zero_() + # Calculate metrics here to avoid complications with closure execution + metrics = {} + if self.metrics_collector.enabled: + try: + metrics = self.metrics_collector.calculate_metric("all") + # Reset the metrics window after calculation + self.metrics_collector.reset_window() + except Exception as e: + # Log error but don't fail training + logger.warning(f"Failed to calculate training metrics: {e}") + metrics = {} + def log_closure(): # We need to check that self.state.global_step > self._globalstep_last_logged because if two # closures are added in a row (which can happen at the end of the training), then it will fail the @@ -1016,6 +1039,10 @@ def log_closure(): if isinstance(self.grad_norm, torch.Tensor) else self.grad_norm ) + + # Add metrics to the logs + logs.update(metrics) + self.log(logs) self.global_step_last_logged = self.state.global_step @@ -1094,6 +1121,13 @@ def train( prefetch_size=args.dataloader_prefetch_size, ) + # Start gradient accumulation cycle and cycle-level timing + # Note: throughput and total_step span the entire gradient accumulation cycle, + # so we use start/stop rather than context manager for better clarity + self.metrics_collector.start_gradient_accumulation_cycle() + self.metrics_collector.start_metric("throughput") + self.metrics_collector.start_metric("total_step") + for inputs in batch_samples: xm.mark_step() step += 1 @@ -1102,7 +1136,11 @@ def 
train( if step % args.gradient_accumulation_steps == 0: self.control = self.callback_handler.on_step_begin(args, self.state, self.control) - loss_step = self.train_step(self.model, inputs, num_items_in_batch=num_items_in_batch) + self.metrics_collector.update_metric_batch_data("throughput", inputs) + + with self.metrics_collector.time_metric("mfu", inputs=inputs): + loss_step = self.train_step(self.model, inputs, num_items_in_batch=num_items_in_batch) + self.running_loss += loss_step.detach() if do_sync_step: @@ -1112,8 +1150,10 @@ def train( self.control = self.callback_handler.on_pre_optimizer_step(args, self.state, self.control) - self.optimizer.step() - self.grad_norm = self.optimizer.grad_norm + # Time optimizer step + with self.metrics_collector.time_metric("optimizer_step", inputs=inputs): + self.optimizer.step() + self.grad_norm = self.optimizer.grad_norm self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control) @@ -1127,6 +1167,9 @@ def train( self.state.epoch = epoch + (step + 1) / steps_in_epoch self.control = self.callback_handler.on_step_end(args, self.state, self.control) xm.mark_step() + self.metrics_collector.stop_metric("throughput") + self.metrics_collector.stop_metric("total_step") + self.metrics_collector.end_gradient_accumulation_cycle(step_number=self.state.global_step) else: self.accelerator.gradient_state.sync_gradients = False self.control = self.callback_handler.on_substep_end(args, self.state, self.control) @@ -1160,6 +1203,9 @@ def train( if self.control.should_training_stop: break + # Report and save training summary metrics + self.report_and_save_summary_metrics() + logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n") self.control = self.callback_handler.on_train_end(args, self.state, self.control) @@ -1238,6 +1284,65 @@ def log(self, logs: dict[str, float]) -> None: self.state.log_history.append(output) self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs) + def report_and_save_summary_metrics(self): + """Report and save comprehensive training summary metrics at the end of training.""" + try: + summary_metrics = self.metrics_collector.calculate_summary_metrics() + if not summary_metrics: + return + + # Save summary metrics to file + summary_file_path = os.path.join(self.args.output_dir, "training_summary_metrics.json") + + # Add metadata to the summary + summary_with_metadata = { + "metadata": { + "timestamp": datetime.now().isoformat(), + "total_training_steps": self.state.global_step, + "total_epochs": self.state.epoch, + "model_name": getattr(self.model, "_name_or_path", "unknown"), + "gradient_accumulation_steps": self.args.gradient_accumulation_steps, + "per_device_train_batch_size": self.args.per_device_train_batch_size, + "learning_rate": self.args.learning_rate, + "tensor_parallel_size": getattr(self.args, "tensor_parallel_size", 1), + "pipeline_parallel_size": getattr(self.args, "pipeline_parallel_size", 1), + "total_neuron_cores": self.metrics_collector.total_neuron_cores, + }, + "metrics": summary_metrics, + } + + with open(summary_file_path, "w") as f: + json.dump(summary_with_metadata, f, indent=2) + + logger.info("=" * 80) + logger.info("TRAINING SUMMARY METRICS") + logger.info("=" * 80) + + # Group and format metrics for better readability + for metric_name, value in summary_metrics.items(): + if isinstance(value, float): + if "time" in metric_name: + logger.info(f"{metric_name}: {value:.4f}s") + elif "per_sec" in metric_name: + 
logger.info(f"{metric_name}: {value:.2f}") + elif ( + "mfu" in metric_name + or "efficiency" in metric_name + or "consistency" in metric_name + or "percent" in metric_name + ): + logger.info(f"{metric_name}: {value:.2f}%") + else: + logger.info(f"{metric_name}: {value:.2f}") + else: + logger.info(f"{metric_name}: {value}") + logger.info("=" * 80) + logger.info(f"Summary metrics saved to: {summary_file_path}") + logger.info("=" * 80) + + except Exception as e: + logger.warning(f"Failed to calculate training summary metrics: {e}") + def _save_checkpoint(self): # Save model checkpoint checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" diff --git a/tests/training/test_metrics.py b/tests/training/test_metrics.py new file mode 100644 index 000000000..654bbb871 --- /dev/null +++ b/tests/training/test_metrics.py @@ -0,0 +1,179 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import pytest +import torch + +from optimum.neuron.models.training.llama.modeling_llama import LlamaForCausalLM +from optimum.neuron.trainers.metrics import TrainingMetricsCollector +from optimum.neuron.trainers.training_args import NeuronTrainingArguments +from optimum.neuron.utils.testing_utils import is_trainium_test + +from .distributed_utils import run_distributed_test +from .utils import MODEL_NAME + + +@pytest.mark.parametrize( + "world_size,tp_size,pp_size", + [ + (8, 1, 1), + (32, 8, 1), + (32, 1, 4), + (32, 8, 4), + ], + ids=["8_1_1", "32_8_1", "32_1_4", "32_8_4"], +) +@is_trainium_test +def test_metrics_distributed_correctness(world_size, tp_size, pp_size, tmpdir): + def _test_metrics_computation(): + args = NeuronTrainingArguments( + output_dir=tmpdir, + enable_throughput_metrics=True, + enable_mfu_metrics=True, + enable_efficiency_metrics=True, + metrics_window_size=3, + ) + + model = LlamaForCausalLM.from_pretrained(MODEL_NAME, args.trn_config) + collector = TrainingMetricsCollector(model, args) + + # Test unit system first + units = collector.get_all_metric_units() + assert units["train/tokens_per_sec"] == "tokens/s" + assert units["train/mfu"] == "%" + assert units["train/step_time"] == "s" + assert units["train/efficiency"] == "%" + assert units["train/forward_time_percent"] == "%" + assert units["train/backward_time_percent"] == "%" + assert units["train/optimizer_time_percent"] == "%" + assert units["train/overhead_time_percent"] == "%" + + inputs = {"input_ids": torch.randint(0, 1000, (4, 32))} + inputs["labels"] = inputs["input_ids"].clone() + + # Test accumulation cycle with all component timings + collector.start_gradient_accumulation_cycle() + collector.start_metric("total_step") + + # Simulate gradient_accumulation_steps=2 + for _ in range(2): + collector.start_metric("forward_pass", inputs) + time.sleep(0.005) + collector.stop_metric("forward_pass") + + collector.start_metric("backward_pass", inputs) + time.sleep(0.005) + collector.stop_metric("backward_pass") + + 
collector.start_metric("optimizer_step", inputs) + time.sleep(0.003) + collector.stop_metric("optimizer_step") + + collector.stop_metric("total_step") + collector.end_gradient_accumulation_cycle(step_number=1) + + # Test throughput and MFU metrics together + for step in range(3): + collector.start_metric("throughput", inputs) + collector.start_metric("mfu", inputs) + start_time = time.perf_counter() + time.sleep(0.02) + elapsed = time.perf_counter() - start_time + collector.stop_metric("throughput", step_number=step) + collector.stop_metric("mfu", step_number=step) + + window_stats = collector.metric_windows["throughput"].get_window_stats() + + assert window_stats["total_tokens"] == 128 * (step + 1) + assert abs(window_stats["total_time"] - elapsed) < 0.05 + + # Validate throughput computation + throughput_metrics = collector.calculate_metric("throughput") + assert "train/tokens_per_sec" in throughput_metrics + + window_stats = collector.metric_windows["throughput"].get_window_stats() + expected_local_rate = window_stats["total_tokens"] / window_stats["total_time"] + expected_global_rate = expected_local_rate * collector.dp_size + actual_global_rate = throughput_metrics["train/tokens_per_sec"] + + relative_error = abs(actual_global_rate - expected_global_rate) / expected_global_rate + assert relative_error < 0.05, ( + f"Throughput calculation failed: expected {expected_global_rate}, got {actual_global_rate} " + f"(relative error={relative_error:.3f})" + ) + + # Validate MFU computation + mfu_metrics = collector.calculate_metric("mfu") + assert mfu_metrics != {}, "MFU metrics should not be empty" + + total_tokens = window_stats["total_tokens"] + total_time = window_stats["total_time"] + + # Use exact same formula as mfu.py implementation + assert collector.model_params is not None, "Model params should be set in the collector" + assert collector.seq_length is not None, "Sequence length should be set in the collector" + + N = collector.model_params + L, H, Q, T = collector.num_layers, collector.num_heads, collector.head_dim, collector.seq_length + flops_per_token = 6 * N + 12 * L * H * Q * T + + system_tokens = total_tokens * collector.dp_size + system_flops_per_iter = flops_per_token * system_tokens + system_actual_flops_per_sec = system_flops_per_iter / total_time + system_peak_flops_per_sec = collector.peak_tflops_per_core * collector.total_neuron_cores * 1e12 + expected_system_mfu = (system_actual_flops_per_sec / system_peak_flops_per_sec) * 100 + + system_mfu_diff = abs(mfu_metrics["train/mfu"] - round(expected_system_mfu, 2)) + assert system_mfu_diff < 0.1, ( + f"System MFU calculation failed: expected {round(expected_system_mfu, 2):.2f}%, got " + f"{mfu_metrics['train/mfu']:.2f}% (diff={system_mfu_diff:.3f}, cores={collector.total_neuron_cores})" + ) + + # Validate efficiency computation (test the total_step timing fix) + efficiency_metrics = collector.calculate_metric("efficiency") + assert efficiency_metrics != {}, "Efficiency metrics should not be empty" + + forward_pct = efficiency_metrics.get("train/forward_time_percent", 0) + backward_pct = efficiency_metrics.get("train/backward_time_percent", 0) + optimizer_pct = efficiency_metrics.get("train/optimizer_time_percent", 0) + overhead_pct = efficiency_metrics.get("train/overhead_time_percent", 0) + total_efficiency = efficiency_metrics.get("train/efficiency", 0) + + # Test that percentages are reasonable + assert forward_pct > 0, "Forward time percentage should be > 0" + assert backward_pct > 0, "Backward time percentage should be 
> 0" + assert optimizer_pct > 0, "Optimizer time percentage should be > 0" + + # Test that efficiency calculation is correct + # It is not completely exact due to rounding, but should be very close. + assert abs((forward_pct + backward_pct + optimizer_pct) - total_efficiency) < 0.02 + assert abs((total_efficiency + overhead_pct) - 100.0) < 0.02 + + # Validate summary metrics + summary_metrics = collector.calculate_summary_metrics() + throughput_data = collector.summary_metrics["throughput"] + if throughput_data["step_times"]: + manual_rates = [ + (tokens / time * collector.dp_size) + for tokens, time in zip(throughput_data["tokens_per_step"], throughput_data["step_times"]) + if time > 0 + ] + expected_summary = sum(manual_rates) / len(manual_rates) + actual_summary = summary_metrics["summary/tokens_per_sec_avg"] + assert abs(actual_summary - expected_summary) < 0.01 + + run_distributed_test(_test_metrics_computation, world_size, tp_size, pp_size)
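
For reference, the collector-driving pattern that this diff adds to `transformers.py` can be reduced to the minimal sketch below. It only uses the API introduced here (`start_gradient_accumulation_cycle`, `start_metric`/`stop_metric`, `time_metric`, `update_metric_batch_data`, `end_gradient_accumulation_cycle`, `calculate_metric`, `calculate_summary_metrics`) and assumes the default NeuronTrainingArguments metric flags. `model`, `optimizer`, `args` and `batches` are hypothetical placeholders, and the plain `loss.backward()` stands in for the trainer's `accelerator.backward(loss)`.

    collector = TrainingMetricsCollector(model, args)

    collector.start_gradient_accumulation_cycle()
    collector.start_metric("throughput")
    collector.start_metric("total_step")

    for inputs in batches:  # one gradient accumulation cycle
        collector.update_metric_batch_data("throughput", inputs)
        # The trainer wraps each train_step in the "mfu" timer; forward/backward are
        # accumulated across the cycle by the ComponentTimingPlugin.
        with collector.time_metric("mfu", inputs=inputs):
            with collector.time_metric("forward_pass", inputs=inputs):
                loss = model(**inputs).loss / len(batches)
            with collector.time_metric("backward_pass", inputs=inputs):
                loss.backward()

    with collector.time_metric("optimizer_step"):
        optimizer.step()

    collector.stop_metric("throughput")
    collector.stop_metric("total_step")
    collector.end_gradient_accumulation_cycle(step_number=1)

    print(collector.calculate_metric("all"))          # train/* metrics
    print(collector.calculate_summary_metrics())      # summary/* metrics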