Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ dev = [
"tox-uv",
]
wandb = ["wandb"]
mlflow = ["mlflow>=3.0"]

[project.urls]
Homepage = "https://github.com/Red-Hat-AI-Innovation-Team/mini_trainer"
Expand Down
10 changes: 10 additions & 0 deletions src/mini_trainer/api_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,16 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
if train_args.wandb_entity:
command.append(f"--wandb-entity={train_args.wandb_entity}")

# mlflow-related arguments
if train_args.mlflow_tracking_uri:
command.append(f"--mlflow-tracking-uri={train_args.mlflow_tracking_uri}")
if train_args.mlflow_experiment_name:
command.append(
f"--mlflow-experiment-name={train_args.mlflow_experiment_name}"
)
if train_args.mlflow_run_name:
command.append(f"--mlflow-run-name={train_args.mlflow_run_name}")
Comment on lines +157 to +165
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

cat -n src/mini_trainer/api_train.py | sed -n '150,175p'

Repository: Red-Hat-AI-Innovation-Team/mini_trainer

Length of output: 1447


🏁 Script executed:

# Get broader context around the MLflow configuration
cat -n src/mini_trainer/api_train.py | sed -n '140,180p'

Repository: Red-Hat-AI-Innovation-Team/mini_trainer

Length of output: 2126


🏁 Script executed:

# Check if there are other MLflow references or documentation in the repo
rg -i "mlflow" --type py -A 2 -B 2 src/mini_trainer/api_train.py | head -80

Repository: Red-Hat-AI-Innovation-Team/mini_trainer

Length of output: 628


🏁 Script executed:

# Find where the MLflow CLI arguments are parsed/handled
rg "mlflow-tracking-uri|mlflow-experiment-name|mlflow-run-name" --type py -B 3 -A 3

Repository: Red-Hat-AI-Innovation-Team/mini_trainer

Length of output: 988


🏁 Script executed:

# Check the training script that receives these arguments
fd "train\.py" src/ -type f

Repository: Red-Hat-AI-Innovation-Team/mini_trainer

Length of output: 253


🏁 Script executed:

# Look for MLflow initialization or argument parsing in the codebase
rg -i "mlflow" --type py -B 2 -A 2 | head -100

Repository: Red-Hat-AI-Innovation-Team/mini_trainer

Length of output: 6208


🏁 Script executed:

cat src/mini_trainer/mlflow_wrapper.py

Repository: Red-Hat-AI-Innovation-Team/mini_trainer

Length of output: 5599


MLflow experiment and run name args should be independent of tracking URI.

Lines 160–165 only append --mlflow-experiment-name and --mlflow-run-name when mlflow_tracking_uri is set. However, the training code enables MLflow if any of the three MLflow args are provided (use_mlflow = any([mlflow_tracking_uri, mlflow_experiment_name, mlflow_run_name])), and MLflow supports using the default local backend without an explicit tracking URI. Move these two conditionals outside the mlflow_tracking_uri block so they append independently.

🛠️ Suggested fix
-    if train_args.mlflow_tracking_uri:
-        command.append(f"--mlflow-tracking-uri={train_args.mlflow_tracking_uri}")
-        if train_args.mlflow_experiment_name:
-            command.append(
-                f"--mlflow-experiment-name={train_args.mlflow_experiment_name}"
-            )
-        if train_args.mlflow_run_name:
-            command.append(f"--mlflow-run-name={train_args.mlflow_run_name}")
+    if train_args.mlflow_tracking_uri:
+        command.append(f"--mlflow-tracking-uri={train_args.mlflow_tracking_uri}")
+    if train_args.mlflow_experiment_name:
+        command.append(
+            f"--mlflow-experiment-name={train_args.mlflow_experiment_name}"
+        )
+    if train_args.mlflow_run_name:
+        command.append(f"--mlflow-run-name={train_args.mlflow_run_name}")
🤖 Prompt for AI Agents
In `@src/mini_trainer/api_train.py` around lines 157 - 165, The MLflow
experiment/run name flags are currently only added when
train_args.mlflow_tracking_uri is set; update the logic in
src/mini_trainer/api_train.py so that the appending of --mlflow-experiment-name
and --mlflow-run-name is done independently of the --mlflow-tracking-uri block:
keep the existing append of
f"--mlflow-tracking-uri={train_args.mlflow_tracking_uri}" inside the if
train_args.mlflow_tracking_uri: block, but move the if
train_args.mlflow_experiment_name: and if train_args.mlflow_run_name: checks out
of that block so they append their flags to the command list even when
mlflow_tracking_uri is None (consistent with use_mlflow = any([...])).


# validation-related arguments
if train_args.validation_split > 0.0:
command.append(f"--validation-split={train_args.validation_split}")
Expand Down
67 changes: 43 additions & 24 deletions src/mini_trainer/async_structured_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,28 @@
from tqdm import tqdm

# Local imports
from mini_trainer import wandb_wrapper
from mini_trainer import wandb_wrapper, mlflow_wrapper
from mini_trainer.wandb_wrapper import check_wandb_available

from mini_trainer.mlflow_wrapper import check_mlflow_available


class AsyncStructuredLogger:
def __init__(self, file_name="training_log.jsonl", use_wandb=False):
def __init__(
self, file_name="training_log.jsonl", use_wandb=False, use_mlflow=False
):
self.file_name = file_name

# wandb init is a special case -- if it is requested but unavailable,
# we should error out early
if use_wandb:
check_wandb_available("initialize wandb")
self.use_wandb = use_wandb

# mlflow init - same pattern as wandb
if use_mlflow:
check_mlflow_available("initialize mlflow")
self.use_mlflow = use_mlflow

# Rich console for prettier output (force_terminal=True works with subprocess streaming)
self.console = Console(force_terminal=True, force_interactive=False)

Expand Down Expand Up @@ -67,23 +74,35 @@ async def log(self, data):
data["timestamp"] = datetime.now().isoformat()
self.logs.append(data)
await self._write_logs_to_file(data)

# log to wandb if enabled and wandb is initialized, but only log this on the MAIN rank

# log to wandb/mlflow if enabled, but only log this on the MAIN rank
# Guard rank checks when the process group isn't initialized (single-process runs)
is_rank0 = not dist.is_initialized() or dist.get_rank() == 0

# wandb already handles timestamps so no need to include
if self.use_wandb and dist.get_rank() == 0:
if self.use_wandb and is_rank0:
wandb_data = {k: v for k, v in data.items() if k != "timestamp"}
wandb_wrapper.log(wandb_data)

# log to mlflow if enabled, only on MAIN rank
# Filter out step from data since it's passed as a separate argument
if self.use_mlflow and is_rank0:
step = data.get("step")
mlflow_data = {
k: v for k, v in data.items() if k not in ("timestamp", "step")
}
mlflow_wrapper.log(mlflow_data, step=step)
except Exception as e:
print(f"\033[1;38;2;0;255;255mError logging data: {e}\033[0m")

async def _write_logs_to_file(self, data):
"""appends to the log instead of writing the whole log each time"""
async with aiofiles.open(self.file_name, "a") as f:
await f.write(json.dumps(data, indent=None) + "\n")

def log_sync(self, data: dict):
"""Runs the log coroutine non-blocking and prints metrics with tqdm-styled progress bar.

Args:
data: Dictionary of metrics to log. Will automatically print a tqdm-formatted
progress bar with ANSI colors if step and steps_per_epoch are present.
Expand All @@ -96,61 +115,61 @@ def log_sync(self, data: dict):
should_print = not dist.is_initialized() or dist.get_rank() == 0
if should_print:
data_with_timestamp = {**data, "timestamp": datetime.now().isoformat()}

# Print the JSON using Rich for syntax highlighting
self.console.print_json(json.dumps(data_with_timestamp))

# Print tqdm-styled progress bar after JSON (prints as new line each time)
# This works correctly with subprocess streaming
if 'step' in data and 'steps_per_epoch' in data and 'epoch' in data:
if "step" in data and "steps_per_epoch" in data and "epoch" in data:
# Initialize tqdm on first call (lazy init to avoid early printing)
if self.train_pbar is None:
# Simple bar format with ANSI colors - we'll add epoch and metrics manually
self.train_bar_format = (
'{bar} '
'\033[33m{percentage:3.0f}%\033[0m │ '
'\033[37m{n}/{total}\033[0m'
"{bar} "
"\033[33m{percentage:3.0f}%\033[0m │ "
"\033[37m{n}/{total}\033[0m"
)
self.train_pbar = tqdm(
total=data['steps_per_epoch'],
total=data["steps_per_epoch"],
bar_format=self.train_bar_format,
ncols=None,
leave=False,
position=0,
file=sys.stdout,
ascii='━╺─', # custom characters matching Rich style
ascii="━╺─", # custom characters matching Rich style
disable=True, # disable auto-display, we'll manually call display()
)

# Reset tqdm if we're in a new epoch
current_step_in_epoch = (data['step'] - 1) % data['steps_per_epoch'] + 1
current_step_in_epoch = (data["step"] - 1) % data["steps_per_epoch"] + 1
if current_step_in_epoch == 1:
self.train_pbar.reset(total=data['steps_per_epoch'])
self.train_pbar.reset(total=data["steps_per_epoch"])

# Update tqdm position
self.train_pbar.n = current_step_in_epoch

# Manually format the complete progress line with metrics using format_meter
bar_str = self.train_pbar.format_meter(
n=current_step_in_epoch,
total=data['steps_per_epoch'],
total=data["steps_per_epoch"],
elapsed=0, # we don't track elapsed time
ncols=None,
bar_format=self.train_bar_format,
ascii='━╺─',
ascii="━╺─",
)

# Prepend the epoch number (1-indexed)
epoch_prefix = f'\033[1;34mEpoch {data["epoch"] + 1}:\033[0m '
epoch_prefix = f"\033[1;34mEpoch {data['epoch'] + 1}:\033[0m "
bar_str = epoch_prefix + bar_str

# Add the metrics to the bar string
metrics_str = (
f" │ \033[32mloss:\033[0m \033[37m{data['loss']:.4f}\033[0m"
f" │ \033[32mlr:\033[0m \033[37m{data['lr']:.2e}\033[0m"
f" │ \033[35m{data['tokens_per_second']:.0f}\033[0m \033[2mtok/s\033[0m"
)

# Print the complete line
print(bar_str + metrics_str, file=sys.stdout, flush=True)

Expand Down
173 changes: 173 additions & 0 deletions src/mini_trainer/mlflow_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
# SPDX-License-Identifier: Apache-2.0

"""
Wrapper for optional mlflow imports that provides consistent error handling
across all processes when mlflow is not installed.
"""

import logging
import os
from typing import Any, Dict, Optional

# Try to import mlflow
try:
import mlflow

MLFLOW_AVAILABLE = True
except ImportError:
MLFLOW_AVAILABLE = False
mlflow = None

logger = logging.getLogger(__name__)

# Store the active run ID to ensure we can resume the run if needed
# This is needed because async logging may lose the thread-local run context
_active_run_id: Optional[str] = None


class MLflowNotAvailableError(ImportError):
    """Error raised when an mlflow feature is used but mlflow is missing.

    Subclasses ImportError so callers that already guard against import
    failures handle this case transparently.
    """


def check_mlflow_available(operation: str) -> None:
    """Fail fast when mlflow is required but could not be imported.

    Args:
        operation: Human-readable description of what was attempted;
            interpolated into the error message.

    Raises:
        MLflowNotAvailableError: If the mlflow package is not installed.
    """
    if MLFLOW_AVAILABLE:
        return
    error_msg = (
        f"Attempted to {operation} but mlflow is not installed. "
        "Please install mlflow with: pip install mlflow"
    )
    logger.error(error_msg)
    raise MLflowNotAvailableError(error_msg)
Comment on lines +28 to +42
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

find . -name "pyproject.toml" -type f | head -5

Repository: Red-Hat-AI-Innovation-Team/mini_trainer

Length of output: 99


🏁 Script executed:

cat pyproject.toml

Repository: Red-Hat-AI-Innovation-Team/mini_trainer

Length of output: 2785


🏁 Script executed:

sed -n '28,42p' src/mini_trainer/mlflow_wrapper.py

Repository: Red-Hat-AI-Innovation-Team/mini_trainer

Length of output: 618


Align MLflow install hint with distribution name.

Line 39 references mini-trainer[mlflow], but the distribution name in pyproject.toml is rhai-innovation-mini-trainer. Update the error message to match.

🔧 Suggested update
-            "Please install mlflow with: pip install 'mini-trainer[mlflow]'"
+            "Please install mlflow with: pip install 'rhai-innovation-mini-trainer[mlflow]'"
🤖 Prompt for AI Agents
In `@src/mini_trainer/mlflow_wrapper.py` around lines 28 - 42, The error message
in check_mlflow_available currently suggests installing "mini-trainer[mlflow]"
which doesn't match the package name; update the error_msg string used in
check_mlflow_available (and raised via MLflowNotAvailableError) to instruct
installing "rhai-innovation-mini-trainer[mlflow]" instead, keeping the same
phrasing and logger.error call and ensuring MLFLOW_AVAILABLE and
MLflowNotAvailableError usage remain unchanged.



def init(
    tracking_uri: Optional[str] = None,
    experiment_name: Optional[str] = None,
    run_name: Optional[str] = None,
    **kwargs,
) -> Any:
    """Start an mlflow run, or attach to the one already active.

    Configuration precedence, highest first:
    1. Explicit arguments to this function
    2. Environment variables (MLFLOW_TRACKING_URI, MLFLOW_EXPERIMENT_NAME)
    3. MLflow's own defaults

    Args:
        tracking_uri: MLflow tracking server URI (e.g. "http://localhost:5000").
            Falls back to the MLFLOW_TRACKING_URI environment variable.
        experiment_name: Experiment to log under. Falls back to the
            MLFLOW_EXPERIMENT_NAME environment variable.
        run_name: Display name for the run.
        **kwargs: Extra arguments forwarded to mlflow.start_run.

    Returns:
        The mlflow.ActiveRun object for the started (or reused) run.

    Raises:
        MLflowNotAvailableError: If mlflow is not installed.
    """
    global _active_run_id
    check_mlflow_available("initialize mlflow")

    # Keyword argument wins over environment variable; if neither is set,
    # mlflow's defaults apply.
    uri = tracking_uri or os.environ.get("MLFLOW_TRACKING_URI")
    if uri:
        mlflow.set_tracking_uri(uri)

    experiment = experiment_name or os.environ.get("MLFLOW_EXPERIMENT_NAME")
    if experiment:
        mlflow.set_experiment(experiment)

    # The explicit run_name parameter takes precedence; drop any duplicate
    # from kwargs so start_run doesn't see it twice.
    kwargs.pop("run_name", None)

    # Attach to an already-active run when one exists; otherwise start fresh.
    run = mlflow.active_run()
    if run is None:
        run = mlflow.start_run(run_name=run_name, **kwargs)
    # Remember the run ID so async contexts can re-attach later.
    _active_run_id = run.info.run_id
    return run


def get_active_run_id() -> Optional[str]:
    """Return the run ID recorded by init(), or None if no run is active."""
    return _active_run_id
Comment on lines +34 to +104
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

find . -name "mlflow_wrapper.py" -type f

Repository: Red-Hat-AI-Innovation-Team/mini_trainer

Length of output: 119


🏁 Script executed:

cat -n src/mini_trainer/mlflow_wrapper.py | head -150

Repository: Red-Hat-AI-Innovation-Team/mini_trainer

Length of output: 5266


🌐 Web query:

MLflow start_run nested calls re-entrance active run behavior documentation

💡 Result:

  • Default (non-nested) behavior: If there’s already an active run in the current thread, calling mlflow.start_run() again raises an exception (“Run with UUID … is already active … end_run() … or start_run(nested=True)”). This is enforced by checking the thread-local active run stack and requiring nested=True to start another run while one is active. [1]

  • Nested runs (nested=True): mlflow.start_run(nested=True) creates a child run under the current active run (the parent is taken from the top of the active-run stack) and pushes the new run onto the stack; child/parent linkage is recorded via the mlflow.parentRunId tag. [1], [2]

  • Explicit parent (parent_run_id=..., newer API): You can start a run nested under a specific parent via start_run(parent_run_id=...). The parent must be ACTIVE, and if there’s already an active run, the provided parent_run_id must match the current active run’s ID (otherwise it errors). [1]

  • “Active run” semantics (re-entrance / stack):

    • mlflow.active_run() returns the top of a thread-local stack (the most recently started, not-yet-ended run in the current thread). [2]
    • mlflow.end_run() pops the top run from that stack and terminates it. [3]
    • When using with mlflow.start_run() as run: ..., the context manager’s __exit__ ends the run only if it’s still present in the active-run stack, checking by run ID (not object identity) to handle cases where tools stop and restart a run with the same ID (a practical “re-entrance” safeguard). [4]

Sources:
[1] MLflow latest mlflow.tracking.fluent.start_run implementation/docs (module source) (mlflow.org)
[2] MLflow latest mlflow.tracking.fluent.active_run docs/implementation (mlflow.org)
[3] MLflow latest mlflow.tracking.fluent.end_run docs/implementation (mlflow.org)
[4] MLflow latest ActiveRun.__exit__ behavior (module source) (mlflow.org)

Citations:


Add nested=True to start_run() calls in log_params() and log() functions

The log_params() (line 96) and log() (line 125) functions attempt to re-enter the active run via mlflow.start_run(run_id=_active_run_id), but MLflow raises an exception when start_run() is called without nested=True while a run is already active. This breaks the normal usage flow: init()log_params()/log()finish().

Pass nested=True to both start_run() calls to explicitly allow re-entrance into the same run:

Fix
# In log_params() - line 96
if _active_run_id:
    with mlflow.start_run(run_id=_active_run_id, nested=True):
        mlflow.log_params(str_params)

# In log() - line 125  
if _active_run_id:
    with mlflow.start_run(run_id=_active_run_id, nested=True):
        mlflow.log_metrics(metrics, step=step)
🤖 Prompt for AI Agents
In `@src/mini_trainer/mlflow_wrapper.py` around lines 33 - 78, log_params() and
log() re-enter an already active MLflow run using
mlflow.start_run(run_id=_active_run_id) which raises unless nested=True is set;
update both functions (log_params and log) so their with mlflow.start_run(...)
calls include nested=True (i.e., with mlflow.start_run(run_id=_active_run_id,
nested=True):) to allow re-entrancy into the existing run when _active_run_id is
set.



def _ensure_run_for_logging() -> None:
    """Re-attach to the stored run when thread-local context has been lost.

    Async logging can execute on a thread that lacks mlflow's thread-local
    run context. When no run is active here but init() recorded a run ID,
    resume that run so subsequent logging calls land in the right place.
    Deliberately not using the context-manager form: that would end the run
    on exit, and the run must stay open across logging calls.
    """
    if not mlflow.active_run() and _active_run_id:
        mlflow.start_run(run_id=_active_run_id)


def log_params(params: Dict[str, Any]) -> None:
    """Record hyperparameters on the active mlflow run.

    Args:
        params: Parameter name -> value mapping. Values are coerced to
            strings, since mlflow stores params as text.

    Raises:
        MLflowNotAvailableError: If mlflow is not installed.
    """
    check_mlflow_available("log params to mlflow")
    _ensure_run_for_logging()
    # MLflow params must be strings
    mlflow.log_params({name: str(value) for name, value in params.items()})


def log(data: Dict[str, Any], step: Optional[int] = None) -> None:
    """Push the numeric entries of ``data`` to mlflow as metrics.

    Entries whose values cannot be converted with float() are silently
    dropped; if nothing numeric remains, no mlflow call is made at all.

    Args:
        data: Mapping of metric names to values.
        step: Optional step index to associate with the metrics.

    Raises:
        MLflowNotAvailableError: If mlflow is not installed.
    """
    check_mlflow_available("log to mlflow")
    metrics: Dict[str, float] = {}
    for name, value in data.items():
        try:
            metrics[name] = float(value)
        except (ValueError, TypeError):
            continue  # non-numeric values are intentionally skipped
    if metrics:
        _ensure_run_for_logging()
        mlflow.log_metrics(metrics, step=step)


def finish() -> None:
    """Terminate the active mlflow run and clear the stored run ID.

    Raises:
        MLflowNotAvailableError: If mlflow is not installed.
    """
    global _active_run_id
    check_mlflow_available("finish mlflow run")
    mlflow.end_run()
    # Forget the run ID so later logging won't resume a finished run.
    _active_run_id = None
Loading
Loading