Skip to content

Commit 1dfe4da

Browse files
Metrics for training (#982)
# What does this PR do?

This PR introduces a training metrics collection system.

- Plugin-based architecture: Modular system with `ThroughputPlugin`, `MFUPlugin`, `EfficiencyPlugin`, and `ComponentTimingPlugin`
- Moving window statistics: Configurable window size for real-time metrics calculation
- Hardware detection: Automatic TRN1/TRN2 platform detection with correct peak FLOPS values

## Core Metrics Implementation

- Throughput metrics: Tokens/second calculation with proper data parallel scaling
- Model FLOPS Utilization (MFU): System-wide MFU calculation using the PaLM paper formula: `6*N + 12*L*H*Q*T` FLOPS per token
- Training efficiency: Breakdown of time spent on forward/backward/optimizer vs. overhead
- Component timing: Individual timing for forward pass, backward pass, optimizer step, and total step
1 parent 99ff466 commit 1dfe4da

File tree

13 files changed

+1385
-9
lines changed

13 files changed

+1385
-9
lines changed
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# coding=utf-8
2+
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
from .collector import TrainingMetricsCollector
17+
from .window import MovingAverageWindow
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# coding=utf-8
2+
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
from __future__ import annotations
17+
18+
from abc import ABC, abstractmethod
19+
from dataclasses import dataclass
20+
from typing import TYPE_CHECKING
21+
22+
from ..training_args import NeuronTrainingArguments
23+
24+
25+
if TYPE_CHECKING:
26+
from .collector import TrainingMetricsCollector
27+
28+
29+
class MetricUnit:
    """Namespace of unit-label strings attached to reported metrics.

    The attributes are plain ``str`` constants (not an ``Enum``), so they can
    be compared with and used as ordinary strings.
    """

    # Durations.
    SECONDS = "s"
    MILLISECONDS = "ms"

    # Rates.
    TOKENS_PER_SECOND = "tokens/s"
    SAMPLES_PER_SECOND = "samples/s"
    TFLOPS = "TFLOP/s"

    # Proportions and plain counts.
    PERCENT = "%"
    RATIO = "ratio"
    COUNT = "count"

    # Placeholder for metrics that carry no unit.
    NONE = ""
39+
40+
41+
@dataclass
class MetricPlugin(ABC):
    """Abstract base for metric plugins; each plugin computes one kind of metric.

    Concrete plugins must implement `is_enabled`, `calculate_realtime` and
    `calculate_summary`. The remaining helpers have defaults that describe a
    plugin producing a single, unit-less metric named after `name`.
    """

    # Plugin name; also the only metric reported by the default `get_metric_names`.
    name: str
    # NOTE(review): presumably marks plugins needing data gathered over several
    # steps — confirm against the collector, which is not visible here.
    requires_accumulation: bool = False
    # NOTE(review): presumably names of plugins/metrics this one relies on — confirm.
    depends_on: list[str] | None = None

    @abstractmethod
    def is_enabled(self, args: NeuronTrainingArguments) -> bool:
        """Tell whether this plugin should be active for the given training arguments."""

    @abstractmethod
    def calculate_realtime(self, window_stats: dict, collector: "TrainingMetricsCollector") -> dict[str, float]:
        """Compute the train/ metrics from the current window data."""

    @abstractmethod
    def calculate_summary(self, summary_data: dict, collector: "TrainingMetricsCollector") -> dict[str, float]:
        """Compute the summary/ metrics from all collected data."""

    def get_metric_names(self) -> list[str]:
        """List the metrics this plugin provides.

        The default is a single metric named after the plugin; multi-metric
        plugins should override this.
        """
        return [self.name]

    def handles_metric(self, metric_name: str) -> bool:
        """Return True when `metric_name` is one of the metrics this plugin provides."""
        provided = self.get_metric_names()
        return metric_name in provided

    def get_metric_units(self) -> dict[str, str]:
        """Map every provided metric name to its unit.

        The default assigns `MetricUnit.NONE` to each metric; subclasses
        should override this to report real units.
        """
        return {metric: MetricUnit.NONE for metric in self.get_metric_names()}

0 commit comments

Comments
 (0)