Skip to content
Merged
Show file tree
Hide file tree
Changes from 47 commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
9625ce6
feat: metrics for training
michaelbenayoun Sep 29, 2025
c212f8c
feat: metrics for training
michaelbenayoun Sep 29, 2025
833888e
feat: metrics for training
michaelbenayoun Sep 29, 2025
9dd04c5
feat: metrics for training
michaelbenayoun Sep 29, 2025
f85638a
feat: metrics for training
michaelbenayoun Sep 29, 2025
dcd5883
feat: metrics for training
michaelbenayoun Sep 29, 2025
3d1fd56
feat: add summary for metrics
michaelbenayoun Sep 29, 2025
4df2d90
feat: add summary for metrics
michaelbenayoun Sep 29, 2025
ea5edf0
feat: effective throughput
michaelbenayoun Sep 29, 2025
cb0a5b3
feat: polishing the metrics classes
michaelbenayoun Sep 29, 2025
879d75a
feat: optimization of metric computation
michaelbenayoun Sep 29, 2025
e99d37d
test: add tests for metrics, wip
michaelbenayoun Sep 29, 2025
902d69e
Add training efficiency
michaelbenayoun Oct 1, 2025
9ba7d53
feat: add overhead time
michaelbenayoun Oct 1, 2025
7d6ff60
feat: cleanup and keep training efficiency
michaelbenayoun Oct 1, 2025
b86dfc2
feat: remove docstrings
michaelbenayoun Oct 1, 2025
5168337
feat: remove backward compatibility code
michaelbenayoun Oct 1, 2025
f2498c1
feat: keep relevant metrics
michaelbenayoun Oct 1, 2025
14cb336
feat: add metrics breakdown
michaelbenayoun Oct 1, 2025
6e391bf
feat: handle the case where there are no metrics
michaelbenayoun Oct 1, 2025
2d8e8e0
fix: change sec to % for percents
michaelbenayoun Oct 1, 2025
090a7c5
remove test_metrics.py
michaelbenayoun Oct 3, 2025
26a0f19
test: add tests for metrics
michaelbenayoun Oct 3, 2025
7ef5ed5
test: improve metrics tests
michaelbenayoun Oct 3, 2025
08c6db4
test: improve metrics tests
michaelbenayoun Oct 3, 2025
5499f19
feat: added plugin system for metrics
michaelbenayoun Oct 3, 2025
48c620d
refactor: delete metric.py file since plugin system changes that
michaelbenayoun Oct 3, 2025
aa9821e
refactor: improve the existing base
michaelbenayoun Oct 3, 2025
8025e1b
refactor: easier imports
michaelbenayoun Oct 3, 2025
b987e04
refactor: easier imports
michaelbenayoun Oct 3, 2025
d8b38fc
test: improve metrics tests
michaelbenayoun Oct 6, 2025
a47e9b7
test: improve metrics tests
michaelbenayoun Oct 6, 2025
1975137
test: improve metrics tests
michaelbenayoun Oct 6, 2025
b4543ea
Merge branch 'main' into metrics
michaelbenayoun Oct 6, 2025
0ed9244
test: improve metrics tests
michaelbenayoun Oct 6, 2025
99e0eb5
fix: mfu computation
michaelbenayoun Oct 10, 2025
7f142e5
fix: mfu computation
michaelbenayoun Oct 10, 2025
c155c57
fix: training efficiency
michaelbenayoun Oct 14, 2025
5302622
fix: add missing files
michaelbenayoun Oct 14, 2025
a7f9ea7
feat: enable metrics collection only for the rank responsible for log…
michaelbenayoun Oct 14, 2025
6c616db
wip: test metrics
michaelbenayoun Oct 15, 2025
6aca088
test: change the assert criteria to account for rounding
michaelbenayoun Oct 15, 2025
a42a095
feat: add hardware specs per dtype
michaelbenayoun Oct 15, 2025
53ea695
fix: broken import
michaelbenayoun Oct 15, 2025
564ee50
refactor: mfu computation in a function
michaelbenayoun Oct 15, 2025
4b5480e
fix: restore finetune_qwen3.sh
michaelbenayoun Oct 15, 2025
86918db
fix: divide the flops constants by 8 for trn2 and take the lnc into a…
michaelbenayoun Oct 29, 2025
681ebec
fix: add useful comment
michaelbenayoun Oct 31, 2025
ea70bb8
fix: cache dtype entry
michaelbenayoun Oct 31, 2025
3e0cdb7
Merge branch 'main' into metrics
michaelbenayoun Oct 31, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions optimum/neuron/trainers/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .collector import TrainingMetricsCollector
from .window import MovingAverageWindow
74 changes: 74 additions & 0 deletions optimum/neuron/trainers/metrics/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING

from ..training_args import NeuronTrainingArguments


if TYPE_CHECKING:
from .collector import TrainingMetricsCollector


class MetricUnit:
    """Unit labels (plain strings) that can be attached to reported metrics.

    The members are ordinary class-level string constants; the class is used as a
    namespace and is not meant to be instantiated.
    """

    # Durations.
    SECONDS = "s"
    MILLISECONDS = "ms"

    # Throughput, expressed per second.
    TOKENS_PER_SECOND = "tokens/s"
    SAMPLES_PER_SECOND = "samples/s"

    # Percentage value.
    PERCENT = "%"

    # Plain count.
    COUNT = "count"

    # Compute rate.
    TFLOPS = "TFLOP/s"

    # Ratio value.
    RATIO = "ratio"

    # Empty label, for metrics that carry no unit.
    NONE = ""


@dataclass
class MetricPlugin(ABC):
    """Abstract base class for metric plugins; each plugin computes one kind of metric.

    Concrete plugins must implement `is_enabled`, `calculate_realtime` and
    `calculate_summary`. The remaining methods have defaults built around `name`
    and may be overridden.

    Attributes:
        name: Identifier of the plugin. By default it is also the single metric
            name returned by `get_metric_names`.
        requires_accumulation: Boolean flag, `False` by default. It is not read
            inside this class (presumably consumed by the collector; confirm there).
        depends_on: Optional list of names, `None` by default. It is not read
            inside this class (presumably the metrics/plugins this one relies on;
            confirm against the collector).
    """

    name: str
    requires_accumulation: bool = False
    depends_on: list[str] | None = None

    @abstractmethod
    def is_enabled(self, args: NeuronTrainingArguments) -> bool:
        """Tell whether this plugin should be active for the given training arguments."""

    @abstractmethod
    def calculate_realtime(self, window_stats: dict, collector: "TrainingMetricsCollector") -> dict[str, float]:
        """Compute the train/ metrics from the data of the current window."""

    @abstractmethod
    def calculate_summary(self, summary_data: dict, collector: "TrainingMetricsCollector") -> dict[str, float]:
        """Compute the summary/ metrics from all the collected data."""

    def get_metric_names(self) -> list[str]:
        """Return the names of the metrics this plugin provides.

        The default is a one-element list holding `name`; plugins that produce
        several metrics should override this.
        """
        return [self.name]

    def handles_metric(self, metric_name: str) -> bool:
        """Return `True` when `metric_name` is one of the names from `get_metric_names`."""
        provided_names = self.get_metric_names()
        return metric_name in provided_names

    def get_metric_units(self) -> dict[str, str]:
        """Map every provided metric name to its unit.

        The default assigns `MetricUnit.NONE` to each name; subclasses override
        this to report real units.
        """
        return {metric: MetricUnit.NONE for metric in self.get_metric_names()}
Loading
Loading