Commit 19d9ee6

RobotSail and claude authored

Add MLflow logging support (#66)
* Add MLflow logging support

  - Add mlflow_wrapper.py for optional MLflow imports with error handling
  - Add mlflow_tracking_uri, mlflow_experiment_name, mlflow_run_name to TrainingArgs
  - Update AsyncStructuredLogger to support MLflow logging
  - Add MLflow CLI args to api_train.py and train.py
  - Initialize MLflow at start, log params, finish at end of training

* Format code with ruff

* Address PR review comments

  - Enable MLflow when any MLflow arg is provided (not just tracking_uri)
  - Only init/finish MLflow on global rank 0 to avoid multiple runs in multi-node

* Fix MLflow logging to use explicit run ID

  Store the run ID when initializing MLflow and use it explicitly when logging
  params/metrics. This fixes an issue where async logging would lose the
  thread-local run context and create a separate MLflow run.

* Fix MLflow logging - don't re-start already active run

  The previous fix incorrectly tried to start a run in log_params() when the
  run was already active from init(). Now log_params() logs directly since the
  run is already active, and log() only resumes the run if it's not currently
  active (for async contexts).

* Address PR review nitpicks

  - Guard dist.get_rank() when the process group isn't initialized in async log()
  - Add mlflow as an optional dependency in pyproject.toml

* Add explicit environment variable fallback for MLflow configuration

  Implement kwarg > env var precedence for mlflow_tracking_uri and
  mlflow_experiment_name, matching the behavior of instructlab-training.
  The configuration now follows this precedence:
  1. Explicit kwargs (highest priority)
  2. Environment variables (MLFLOW_TRACKING_URI, MLFLOW_EXPERIMENT_NAME)
  3. MLflow defaults (lowest priority)

* Add async-safe pattern to log_params for thread-local context handling

  Mirror the pattern used in log() to handle cases where the thread-local
  MLflow context is lost in async contexts.

* Fix MLflow run context handling - don't use context manager for resume

  Using `with mlflow.start_run(run_id=...)` as a context manager ends the run
  when the block exits, breaking subsequent logging calls. Changed to call
  start_run() without a context manager to keep the run active.

* Format mlflow_wrapper.py with ruff

* guard against active mlflow runs

* comment

* provide instructions when loggers are not available but user requests it

* messaging

* updates

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
1 parent f122c80 commit 19d9ee6

File tree

6 files changed: +273 −27 lines

pyproject.toml

Lines changed: 1 addition & 0 deletions

@@ -62,6 +62,7 @@ dev = [
     "tox-uv",
 ]
 wandb = ["wandb"]
+mlflow = ["mlflow>=3.0"]

 [project.urls]
 Homepage = "https://github.com/Red-Hat-AI-Innovation-Team/mini_trainer"

src/mini_trainer/api_train.py

Lines changed: 10 additions & 0 deletions

@@ -154,6 +154,16 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
     if train_args.wandb_entity:
         command.append(f"--wandb-entity={train_args.wandb_entity}")

+    # mlflow-related arguments
+    if train_args.mlflow_tracking_uri:
+        command.append(f"--mlflow-tracking-uri={train_args.mlflow_tracking_uri}")
+    if train_args.mlflow_experiment_name:
+        command.append(
+            f"--mlflow-experiment-name={train_args.mlflow_experiment_name}"
+        )
+    if train_args.mlflow_run_name:
+        command.append(f"--mlflow-run-name={train_args.mlflow_run_name}")
+
     # validation-related arguments
     if train_args.validation_split > 0.0:
         command.append(f"--validation-split={train_args.validation_split}")
src/mini_trainer/async_structured_logger.py

Lines changed: 43 additions & 24 deletions

@@ -14,21 +14,28 @@
 from tqdm import tqdm

 # Local imports
-from mini_trainer import wandb_wrapper
+from mini_trainer import wandb_wrapper, mlflow_wrapper
 from mini_trainer.wandb_wrapper import check_wandb_available
-
+from mini_trainer.mlflow_wrapper import check_mlflow_available


 class AsyncStructuredLogger:
-    def __init__(self, file_name="training_log.jsonl", use_wandb=False):
+    def __init__(
+        self, file_name="training_log.jsonl", use_wandb=False, use_mlflow=False
+    ):
         self.file_name = file_name
-
+
         # wandb init is a special case -- if it is requested but unavailable,
         # we should error out early
         if use_wandb:
             check_wandb_available("initialize wandb")
         self.use_wandb = use_wandb

+        # mlflow init - same pattern as wandb
+        if use_mlflow:
+            check_mlflow_available("initialize mlflow")
+        self.use_mlflow = use_mlflow
+
         # Rich console for prettier output (force_terminal=True works with subprocess streaming)
         self.console = Console(force_terminal=True, force_interactive=False)

@@ -67,23 +74,35 @@ async def log(self, data):
             data["timestamp"] = datetime.now().isoformat()
             self.logs.append(data)
             await self._write_logs_to_file(data)
-
-            # log to wandb if enabled and wandb is initialized, but only log this on the MAIN rank
+
+            # log to wandb/mlflow if enabled, but only log this on the MAIN rank
+            # Guard rank checks when the process group isn't initialized (single-process runs)
+            is_rank0 = not dist.is_initialized() or dist.get_rank() == 0
+
             # wandb already handles timestamps so no need to include
-            if self.use_wandb and dist.get_rank() == 0:
+            if self.use_wandb and is_rank0:
                 wandb_data = {k: v for k, v in data.items() if k != "timestamp"}
                 wandb_wrapper.log(wandb_data)
+
+            # log to mlflow if enabled, only on MAIN rank
+            # Filter out step from data since it's passed as a separate argument
+            if self.use_mlflow and is_rank0:
+                step = data.get("step")
+                mlflow_data = {
+                    k: v for k, v in data.items() if k not in ("timestamp", "step")
+                }
+                mlflow_wrapper.log(mlflow_data, step=step)
         except Exception as e:
             print(f"\033[1;38;2;0;255;255mError logging data: {e}\033[0m")

     async def _write_logs_to_file(self, data):
         """appends to the log instead of writing the whole log each time"""
         async with aiofiles.open(self.file_name, "a") as f:
             await f.write(json.dumps(data, indent=None) + "\n")
-
+
     def log_sync(self, data: dict):
         """Runs the log coroutine non-blocking and prints metrics with tqdm-styled progress bar.
-
+
         Args:
             data: Dictionary of metrics to log. Will automatically print a tqdm-formatted
                 progress bar with ANSI colors if step and steps_per_epoch are present.

@@ -96,61 +115,61 @@ def log_sync(self, data: dict):
         should_print = not dist.is_initialized() or dist.get_rank() == 0
         if should_print:
             data_with_timestamp = {**data, "timestamp": datetime.now().isoformat()}
-
+
             # Print the JSON using Rich for syntax highlighting
             self.console.print_json(json.dumps(data_with_timestamp))
-
+
             # Print tqdm-styled progress bar after JSON (prints as new line each time)
             # This works correctly with subprocess streaming
-            if 'step' in data and 'steps_per_epoch' in data and 'epoch' in data:
+            if "step" in data and "steps_per_epoch" in data and "epoch" in data:
                 # Initialize tqdm on first call (lazy init to avoid early printing)
                 if self.train_pbar is None:
                     # Simple bar format with ANSI colors - we'll add epoch and metrics manually
                     self.train_bar_format = (
-                        '{bar} '
-                        '\033[33m{percentage:3.0f}%\033[0m │ '
-                        '\033[37m{n}/{total}\033[0m'
+                        "{bar} "
+                        "\033[33m{percentage:3.0f}%\033[0m │ "
+                        "\033[37m{n}/{total}\033[0m"
                     )
                     self.train_pbar = tqdm(
-                        total=data['steps_per_epoch'],
+                        total=data["steps_per_epoch"],
                         bar_format=self.train_bar_format,
                         ncols=None,
                         leave=False,
                         position=0,
                         file=sys.stdout,
-                        ascii='━╺─',  # custom characters matching Rich style
+                        ascii="━╺─",  # custom characters matching Rich style
                         disable=True,  # disable auto-display, we'll manually call display()
                     )

                 # Reset tqdm if we're in a new epoch
-                current_step_in_epoch = (data['step'] - 1) % data['steps_per_epoch'] + 1
+                current_step_in_epoch = (data["step"] - 1) % data["steps_per_epoch"] + 1
                 if current_step_in_epoch == 1:
-                    self.train_pbar.reset(total=data['steps_per_epoch'])
+                    self.train_pbar.reset(total=data["steps_per_epoch"])

                 # Update tqdm position
                 self.train_pbar.n = current_step_in_epoch

                 # Manually format the complete progress line with metrics using format_meter
                 bar_str = self.train_pbar.format_meter(
                     n=current_step_in_epoch,
-                    total=data['steps_per_epoch'],
+                    total=data["steps_per_epoch"],
                     elapsed=0,  # we don't track elapsed time
                     ncols=None,
                     bar_format=self.train_bar_format,
-                    ascii='━╺─',
+                    ascii="━╺─",
                 )

                 # Prepend the epoch number (1-indexed)
-                epoch_prefix = f'\033[1;34mEpoch {data["epoch"] + 1}:\033[0m '
+                epoch_prefix = f"\033[1;34mEpoch {data['epoch'] + 1}:\033[0m "
                 bar_str = epoch_prefix + bar_str
-
+
                 # Add the metrics to the bar string
                 metrics_str = (
                     f" │ \033[32mloss:\033[0m \033[37m{data['loss']:.4f}\033[0m"
                     f" │ \033[32mlr:\033[0m \033[37m{data['lr']:.2e}\033[0m"
                     f" │ \033[35m{data['tokens_per_second']:.0f}\033[0m \033[2mtok/s\033[0m"
                 )
-
+
                 # Print the complete line
                 print(bar_str + metrics_str, file=sys.stdout, flush=True)
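The payload preparation inside `log()` can be shown standalone. This is a simplified sketch: `split_mlflow_payload` is a hypothetical helper, not a function in the codebase, but the dict comprehension mirrors the one in the diff, which pulls `step` out (MLflow takes it as a separate argument) and drops `timestamp` (MLflow records its own).

```python
from typing import Any, Dict, Optional, Tuple


def split_mlflow_payload(data: Dict[str, Any]) -> Tuple[Dict[str, Any], Optional[int]]:
    """Split a log record into (metrics, step) the way log() prepares
    data before calling mlflow_wrapper.log(mlflow_data, step=step)."""
    step = data.get("step")
    # timestamp is excluded because MLflow timestamps metrics itself;
    # step is excluded because it's passed as a separate argument
    metrics = {k: v for k, v in data.items() if k not in ("timestamp", "step")}
    return metrics, step


metrics, step = split_mlflow_payload(
    {"step": 7, "loss": 0.42, "timestamp": "2025-01-01T00:00:00"}
)
print(metrics, step)  # → {'loss': 0.42} 7
```

Keeping this split in the logger (rather than in the wrapper) lets the same `data` dict feed the JSONL file, wandb, and MLflow with each backend's expectations handled at the call site.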

src/mini_trainer/mlflow_wrapper.py

Lines changed: 173 additions & 0 deletions

@@ -0,0 +1,173 @@
+# SPDX-License-Identifier: Apache-2.0
+
+"""
+Wrapper for optional mlflow imports that provides consistent error handling
+across all processes when mlflow is not installed.
+"""
+
+import logging
+import os
+from typing import Any, Dict, Optional
+
+# Try to import mlflow
+try:
+    import mlflow
+
+    MLFLOW_AVAILABLE = True
+except ImportError:
+    MLFLOW_AVAILABLE = False
+    mlflow = None
+
+logger = logging.getLogger(__name__)
+
+# Store the active run ID to ensure we can resume the run if needed
+# This is needed because async logging may lose the thread-local run context
+_active_run_id: Optional[str] = None
+
+
+class MLflowNotAvailableError(ImportError):
+    """Raised when mlflow functions are called but mlflow is not installed."""
+
+    pass
+
+
+def check_mlflow_available(operation: str) -> None:
+    """Check if mlflow is available, raise error if not."""
+    if not MLFLOW_AVAILABLE:
+        error_msg = (
+            f"Attempted to {operation} but mlflow is not installed. "
+            "Please install mlflow with: pip install mlflow"
+        )
+        logger.error(error_msg)
+        raise MLflowNotAvailableError(error_msg)
+
+
+def init(
+    tracking_uri: Optional[str] = None,
+    experiment_name: Optional[str] = None,
+    run_name: Optional[str] = None,
+    **kwargs,
+) -> Any:
+    """
+    Initialize an mlflow run. Raises MLflowNotAvailableError if mlflow is not installed.
+
+    Configuration follows a precedence hierarchy:
+    1. Explicit kwargs (highest priority)
+    2. Environment variables (MLFLOW_TRACKING_URI, MLFLOW_EXPERIMENT_NAME)
+    3. MLflow defaults (lowest priority)
+
+    Args:
+        tracking_uri: MLflow tracking server URI (e.g., "http://localhost:5000").
+            Falls back to the MLFLOW_TRACKING_URI environment variable if not provided.
+        experiment_name: Name of the experiment.
+            Falls back to the MLFLOW_EXPERIMENT_NAME environment variable if not provided.
+        run_name: Name of the run.
+        **kwargs: Additional arguments to pass to mlflow.start_run.
+
+    Returns:
+        mlflow.ActiveRun object if successful.
+
+    Raises:
+        MLflowNotAvailableError: If mlflow is not installed.
+    """
+    global _active_run_id
+    check_mlflow_available("initialize mlflow")
+
+    # Apply kwarg > env var precedence for tracking_uri
+    effective_tracking_uri = tracking_uri or os.environ.get("MLFLOW_TRACKING_URI")
+    if effective_tracking_uri:
+        mlflow.set_tracking_uri(effective_tracking_uri)
+
+    # Apply kwarg > env var precedence for experiment_name
+    effective_experiment_name = experiment_name or os.environ.get(
+        "MLFLOW_EXPERIMENT_NAME"
+    )
+    if effective_experiment_name:
+        mlflow.set_experiment(effective_experiment_name)
+
+    # Remove run_name from kwargs if present to avoid duplicate keyword argument
+    # The explicit run_name parameter takes precedence
+    kwargs.pop("run_name", None)
+
+    # Reuse existing active run if one exists, otherwise start a new one
+    active_run = mlflow.active_run()
+    if active_run is not None:
+        run = active_run
+    else:
+        run = mlflow.start_run(run_name=run_name, **kwargs)
+    _active_run_id = run.info.run_id
+    return run
+
+
+def get_active_run_id() -> Optional[str]:
+    """Get the active run ID that was started by init()."""
+    return _active_run_id
+
+
+def _ensure_run_for_logging() -> None:
+    """Ensure there's an active MLflow run for logging.
+
+    This helper handles async contexts where the thread-local run context may be lost.
+    If no active run exists but we have a stored run ID, it resumes that run.
+    """
+    active_run = mlflow.active_run()
+    if not active_run and _active_run_id:
+        # No active run in this thread but we have a stored run ID - resume it
+        # This can happen in async contexts where thread-local context is lost
+        # Note: We don't use a context manager here because it would end the run on exit
+        mlflow.start_run(run_id=_active_run_id)
+
+
+def log_params(params: Dict[str, Any]) -> None:
+    """
+    Log parameters to mlflow. Raises MLflowNotAvailableError if mlflow is not installed.
+
+    Args:
+        params: Dictionary of parameters to log.
+
+    Raises:
+        MLflowNotAvailableError: If mlflow is not installed.
+    """
+    check_mlflow_available("log params to mlflow")
+    # MLflow params must be strings
+    str_params = {k: str(v) for k, v in params.items()}
+
+    _ensure_run_for_logging()
+    mlflow.log_params(str_params)
+
+
+def log(data: Dict[str, Any], step: Optional[int] = None) -> None:
+    """
+    Log metrics to mlflow. Raises MLflowNotAvailableError if mlflow is not installed.
+
+    Args:
+        data: Dictionary of data to log (non-numeric values will be skipped).
+        step: Optional step number for the metrics.
+
+    Raises:
+        MLflowNotAvailableError: If mlflow is not installed.
+    """
+    check_mlflow_available("log to mlflow")
+    # Filter to only numeric values for metrics
+    metrics = {}
+    for k, v in data.items():
+        try:
+            metrics[k] = float(v)
+        except (ValueError, TypeError):
+            pass  # Skip non-numeric values
+    if metrics:
+        _ensure_run_for_logging()
+        mlflow.log_metrics(metrics, step=step)
+
+
+def finish() -> None:
+    """
+    End the mlflow run. Raises MLflowNotAvailableError if mlflow is not installed.
+
+    Raises:
+        MLflowNotAvailableError: If mlflow is not installed.
+    """
+    global _active_run_id
+    check_mlflow_available("finish mlflow run")
+    mlflow.end_run()
+    _active_run_id = None
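The kwarg > env var > default precedence in `init()` reduces to a one-line fallback per setting. The sketch below isolates that logic without importing mlflow; `resolve_setting` is a hypothetical helper, not part of the wrapper, and a `None` result is where MLflow's own defaults would take over.

```python
import os
from typing import Optional


def resolve_setting(kwarg_value: Optional[str], env_var: str) -> Optional[str]:
    """Explicit kwarg wins; otherwise fall back to the environment variable;
    otherwise return None so MLflow applies its own defaults."""
    return kwarg_value or os.environ.get(env_var)


os.environ["MLFLOW_TRACKING_URI"] = "http://env-server:5000"
print(resolve_setting(None, "MLFLOW_TRACKING_URI"))
# → http://env-server:5000 (env var fills in when no kwarg is given)
print(resolve_setting("http://kwarg:5000", "MLFLOW_TRACKING_URI"))
# → http://kwarg:5000 (explicit kwarg overrides the env var)
```

One subtlety of the `or` idiom: an empty-string kwarg also falls through to the environment variable, which matches the wrapper's behavior of only calling `mlflow.set_tracking_uri` / `mlflow.set_experiment` when the resolved value is truthy.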
