Skip to content

Commit 4dc0189

Browse files
authored
Feat/opentelemetry (#3215)
1 parent 2436203 commit 4dc0189

File tree

8 files changed

+687
-1
lines changed

8 files changed

+687
-1
lines changed
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
base_model: NousResearch/Llama-3.2-1B
2+
model_type: AutoModelForCausalLM
3+
tokenizer_type: AutoTokenizer
4+
5+
load_in_4bit: true
6+
7+
datasets:
8+
- path: mhenrichsen/alpaca_2k_test
9+
type: alpaca
10+
11+
output_dir: ./outputs/opentelemetry-example
12+
13+
adapter: qlora
14+
sequence_len: 512
15+
sample_packing: false
16+
17+
lora_r: 32
18+
lora_alpha: 16
19+
lora_dropout: 0.05
20+
lora_target_linear: true
21+
22+
# OpenTelemetry Configuration
23+
use_otel_metrics: true
24+
otel_metrics_host: "localhost"
25+
otel_metrics_port: 8000
26+
27+
# Disable WandB
28+
use_wandb: false
29+
30+
gradient_accumulation_steps: 4
31+
micro_batch_size: 2
32+
num_epochs: 1
33+
optimizer: paged_adamw_32bit
34+
lr_scheduler: cosine
35+
learning_rate: 0.0002
36+
37+
bf16: auto
38+
tf32: false
39+
40+
gradient_checkpointing: true
41+
logging_steps: 1
42+
flash_attention: false
43+
44+
warmup_ratio: 0.1
45+
evals_per_epoch: 2
46+
saves_per_epoch: 1
47+
weight_decay: 0.0
48+
49+
special_tokens:
50+
pad_token: "<|end_of_text|>"

setup.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,12 @@ def get_package_version():
159159
"llmcompressor==0.5.1",
160160
],
161161
"fbgemm-gpu": ["fbgemm-gpu-genai>=1.2.0"],
162+
"opentelemetry": [
163+
"opentelemetry-api",
164+
"opentelemetry-sdk",
165+
"opentelemetry-exporter-prometheus",
166+
"prometheus-client",
167+
],
162168
}
163169
install_requires, dependency_links, extras_require_build = parse_requirements(
164170
extras_require

src/axolotl/core/builders/base.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,11 @@
2929

3030
from axolotl.integrations.base import PluginManager
3131
from axolotl.monkeypatch.trainer.lr import patch_trainer_get_lr
32-
from axolotl.utils import is_comet_available, is_mlflow_available
32+
from axolotl.utils import (
33+
is_comet_available,
34+
is_mlflow_available,
35+
is_opentelemetry_available,
36+
)
3337
from axolotl.utils.callbacks import (
3438
GCCallback,
3539
SaveAxolotlConfigtoWandBCallback,
@@ -134,6 +138,12 @@ def get_callbacks(self) -> list[TrainerCallback]:
134138
callbacks.append(
135139
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
136140
)
141+
if self.cfg.use_otel_metrics and is_opentelemetry_available():
142+
from axolotl.utils.callbacks.opentelemetry import (
143+
OpenTelemetryMetricsCallback,
144+
)
145+
146+
callbacks.append(OpenTelemetryMetricsCallback(self.cfg))
137147
if self.cfg.save_first_step:
138148
callbacks.append(SaveModelOnFirstStepCallback())
139149

src/axolotl/utils/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,13 @@ def is_comet_available():
1717
return importlib.util.find_spec("comet_ml") is not None
1818

1919

20+
def is_opentelemetry_available():
    """Return True when both the OpenTelemetry packages and the Prometheus
    client library are importable in the current environment."""
    required_packages = ("opentelemetry", "prometheus_client")
    return all(
        importlib.util.find_spec(package) is not None
        for package in required_packages
    )
25+
26+
2027
def get_pytorch_version() -> tuple[int, int, int]:
2128
"""
2229
Get Pytorch version as a tuple of (major, minor, patch).
Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
"""OpenTelemetry metrics callback for Axolotl training"""
2+
3+
import threading
4+
from typing import Dict, Optional
5+
6+
from transformers import (
7+
TrainerCallback,
8+
TrainerControl,
9+
TrainerState,
10+
TrainingArguments,
11+
)
12+
13+
from axolotl.utils.logging import get_logger
14+
15+
LOG = get_logger(__name__)
16+
17+
try:
18+
from opentelemetry import metrics
19+
from opentelemetry.exporter.prometheus import PrometheusMetricReader
20+
from opentelemetry.metrics import set_meter_provider
21+
from opentelemetry.sdk.metrics import MeterProvider as SDKMeterProvider
22+
from prometheus_client import start_http_server
23+
24+
OPENTELEMETRY_AVAILABLE = True
25+
except ImportError:
26+
LOG.warning("OpenTelemetry not available. pip install [opentelemetry]")
27+
OPENTELEMETRY_AVAILABLE = False
28+
29+
30+
class OpenTelemetryMetricsCallback(TrainerCallback):
    """
    TrainerCallback that exports training metrics to OpenTelemetry/Prometheus.

    This callback automatically tracks key training metrics including:
    - Training loss
    - Evaluation loss
    - Learning rate
    - Epoch progress
    - Global step count
    - Gradient norm

    Metrics are exposed via HTTP endpoint for Prometheus scraping.
    """

    def __init__(self, cfg):
        """Set up the OpenTelemetry meter provider and training metrics.

        Args:
            cfg: Axolotl config object; reads ``otel_metrics_host``
                (default ``"localhost"``) and ``otel_metrics_port``
                (default ``8000``).
        """
        if not OPENTELEMETRY_AVAILABLE:
            LOG.warning("OpenTelemetry not available, metrics will not be collected")
            self.metrics_enabled = False
            return

        self.cfg = cfg
        self.metrics_host = getattr(cfg, "otel_metrics_host", "localhost")
        self.metrics_port = getattr(cfg, "otel_metrics_port", 8000)
        self.metrics_enabled = True
        self.server_started = False
        # Guards the lazily-built eval-metric gauge cache (callbacks may be
        # invoked from more than one thread depending on the trainer setup).
        self.metrics_lock = threading.Lock()

        try:
            # Create Prometheus metrics reader
            prometheus_reader = PrometheusMetricReader()

            # Create meter provider with Prometheus exporter
            provider = SDKMeterProvider(metric_readers=[prometheus_reader])
            set_meter_provider(provider)

            # Get meter for creating metrics
            self.meter = metrics.get_meter("axolotl.training")

            # Create metrics
            self._create_metrics()

        except Exception as e:
            LOG.warning(f"Failed to initialize OpenTelemetry metrics: {e}")
            self.metrics_enabled = False

    def _create_metrics(self):
        """Create all metrics that will be tracked"""
        self.train_loss_gauge = self.meter.create_gauge(
            name="axolotl_train_loss",
            description="Current training loss",
            unit="1",
        )

        self.eval_loss_gauge = self.meter.create_gauge(
            name="axolotl_eval_loss",
            description="Current evaluation loss",
            unit="1",
        )

        self.learning_rate_gauge = self.meter.create_gauge(
            name="axolotl_learning_rate",
            description="Current learning rate",
            unit="1",
        )

        self.epoch_gauge = self.meter.create_gauge(
            name="axolotl_epoch",
            description="Current training epoch",
            unit="1",
        )

        self.global_step_counter = self.meter.create_counter(
            name="axolotl_global_steps",
            description="Total training steps completed",
            unit="1",
        )

        self.grad_norm_gauge = self.meter.create_gauge(
            name="axolotl_gradient_norm",
            description="Gradient norm",
            unit="1",
        )

        self.memory_usage_gauge = self.meter.create_gauge(
            name="axolotl_memory_usage",
            description="Current memory usage in MB",
            unit="MB",
        )

        # Cache of dynamically-created eval gauges, keyed by metric name.
        # Seeded with the eval-loss gauge so on_evaluate reuses it instead of
        # creating a second, conflicting "axolotl_eval_loss" instrument.
        self._eval_gauges = {"axolotl_eval_loss": self.eval_loss_gauge}

    def _start_metrics_server(self):
        """Start the HTTP server for metrics exposure"""
        if self.server_started:
            return

        try:
            start_http_server(self.metrics_port, addr=self.metrics_host)
            self.server_started = True
            LOG.info(
                f"OpenTelemetry metrics server started on http://{self.metrics_host}:{self.metrics_port}/metrics"
            )

        except Exception as e:
            LOG.error(f"Failed to start OpenTelemetry metrics server: {e}")

    def on_train_begin(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """Called at the beginning of training"""
        if not self.metrics_enabled:
            return

        self._start_metrics_server()
        LOG.info("OpenTelemetry metrics collection started")

    def on_log(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        logs: Optional[Dict[str, float]] = None,
        **kwargs,
    ):
        """Called when logging occurs"""
        if not self.metrics_enabled or not logs:
            return

        if "loss" in logs:
            self.train_loss_gauge.set(logs["loss"])

        if "eval_loss" in logs:
            self.eval_loss_gauge.set(logs["eval_loss"])

        if "learning_rate" in logs:
            self.learning_rate_gauge.set(logs["learning_rate"])

        if "epoch" in logs:
            self.epoch_gauge.set(logs["epoch"])

        if "grad_norm" in logs:
            self.grad_norm_gauge.set(logs["grad_norm"])

        if "memory_usage" in logs:
            self.memory_usage_gauge.set(logs["memory_usage"])

    def on_step_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """Called at the end of each training step"""
        if not self.metrics_enabled:
            return

        # Update step counter and epoch
        self.global_step_counter.add(1)
        if state.epoch is not None:
            self.epoch_gauge.set(state.epoch)

    def on_evaluate(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        metrics: Optional[Dict[str, float]] = None,
        **kwargs,
    ):
        """Called after evaluation.

        NOTE: the ``metrics`` parameter (HF Trainer callback convention)
        shadows the module-level ``opentelemetry.metrics`` import inside this
        method; only ``self.meter`` is used here, so that is safe.
        """
        if not self.metrics_enabled or not metrics:
            return

        if "eval_loss" in metrics:
            self.eval_loss_gauge.set(metrics["eval_loss"])

        # Record any other eval metrics as gauges
        for key, value in metrics.items():
            if key.startswith("eval_") and isinstance(value, (int, float)):
                gauge_name = f"axolotl_{key}"
                try:
                    # Create the gauge only once per metric name; creating a
                    # new instrument with the same name on every evaluation
                    # triggers duplicate-instrument warnings in the SDK.
                    with self.metrics_lock:
                        gauge = self._eval_gauges.get(gauge_name)
                        if gauge is None:
                            gauge = self.meter.create_gauge(
                                name=gauge_name,
                                description=f"Evaluation metric: {key}",
                                unit="1",
                            )
                            self._eval_gauges[gauge_name] = gauge
                    gauge.set(value)
                except Exception as e:
                    LOG.warning(f"Failed to create/update metric {gauge_name}: {e}")

    def on_train_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """Called at the end of training"""
        if not self.metrics_enabled:
            return

        LOG.info("Training completed. OpenTelemetry metrics collection finished.")
        LOG.info(
            f"Metrics are still available at http://{self.metrics_host}:{self.metrics_port}/metrics"
        )

src/axolotl/utils/schemas/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
GradioConfig,
3131
LISAConfig,
3232
MLFlowConfig,
33+
OpenTelemetryConfig,
3334
RayConfig,
3435
WandbConfig,
3536
)
@@ -60,6 +61,7 @@ class AxolotlInputConfig(
6061
WandbConfig,
6162
MLFlowConfig,
6263
CometConfig,
64+
OpenTelemetryConfig,
6365
LISAConfig,
6466
GradioConfig,
6567
RayConfig,

src/axolotl/utils/schemas/integrations.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,3 +176,27 @@ class RayConfig(BaseModel):
176176
"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."
177177
},
178178
)
179+
180+
181+
class OpenTelemetryConfig(BaseModel):
    """OpenTelemetry configuration subset"""

    # Master switch: when true, the trainer attaches the OpenTelemetry
    # metrics callback (also requires the optional dependencies installed).
    use_otel_metrics: bool | None = Field(
        default=False,
        json_schema_extra={
            "description": "Enable OpenTelemetry metrics collection and Prometheus export"
        },
    )
    # Bind address for the Prometheus scrape endpoint; "localhost" keeps
    # the metrics server local-only by default.
    otel_metrics_host: str | None = Field(
        default="localhost",
        json_schema_extra={
            "title": "OpenTelemetry Metrics Host",
            "description": "Host to bind the OpenTelemetry metrics server to",
        },
    )
    # TCP port the metrics HTTP server listens on.
    otel_metrics_port: int | None = Field(
        default=8000,
        json_schema_extra={
            "description": "Port for the Prometheus metrics HTTP server"
        },
    )

0 commit comments

Comments
 (0)