|
| 1 | +# coding=utf-8 |
| 2 | +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | + |
| 16 | +from __future__ import annotations |
| 17 | + |
| 18 | +from abc import ABC, abstractmethod |
| 19 | +from dataclasses import dataclass |
| 20 | +from typing import TYPE_CHECKING |
| 21 | + |
| 22 | +from ..training_args import NeuronTrainingArguments |
| 23 | + |
| 24 | + |
| 25 | +if TYPE_CHECKING: |
| 26 | + from .collector import TrainingMetricsCollector |
| 27 | + |
| 28 | + |
| 29 | +class MetricUnit: |
| 30 | + SECONDS = "s" |
| 31 | + MILLISECONDS = "ms" |
| 32 | + TOKENS_PER_SECOND = "tokens/s" |
| 33 | + SAMPLES_PER_SECOND = "samples/s" |
| 34 | + PERCENT = "%" |
| 35 | + COUNT = "count" |
| 36 | + TFLOPS = "TFLOP/s" |
| 37 | + RATIO = "ratio" |
| 38 | + NONE = "" |
| 39 | + |
| 40 | + |
| 41 | +@dataclass |
| 42 | +class MetricPlugin(ABC): |
| 43 | + """Base class for metrics plugins. Each plugin calculates one type of metric.""" |
| 44 | + |
| 45 | + name: str |
| 46 | + requires_accumulation: bool = False |
| 47 | + depends_on: list[str] | None = None |
| 48 | + |
| 49 | + @abstractmethod |
| 50 | + def is_enabled(self, args: NeuronTrainingArguments) -> bool: |
| 51 | + """Check if this plugin should be active.""" |
| 52 | + pass |
| 53 | + |
| 54 | + @abstractmethod |
| 55 | + def calculate_realtime(self, window_stats: dict, collector: "TrainingMetricsCollector") -> dict[str, float]: |
| 56 | + """Calculate train/ metrics from current window data.""" |
| 57 | + pass |
| 58 | + |
| 59 | + @abstractmethod |
| 60 | + def calculate_summary(self, summary_data: dict, collector: "TrainingMetricsCollector") -> dict[str, float]: |
| 61 | + """Calculate summary/ metrics from all collected data.""" |
| 62 | + pass |
| 63 | + |
| 64 | + def get_metric_names(self) -> list[str]: |
| 65 | + """Get the metrics this plugin provides. Override for multi-metric plugins.""" |
| 66 | + return [self.name] |
| 67 | + |
| 68 | + def handles_metric(self, metric_name: str) -> bool: |
| 69 | + """Check if this plugin handles the given metric.""" |
| 70 | + return metric_name in self.get_metric_names() |
| 71 | + |
| 72 | + def get_metric_units(self) -> dict[str, str]: |
| 73 | + """Get units for each metric this plugin produces. Override in subclasses.""" |
| 74 | + return dict.fromkeys(self.get_metric_names(), MetricUnit.NONE) |
0 commit comments