diff --git a/pyproject.toml b/pyproject.toml index b492909..e29e621 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,6 +62,7 @@ dev = [ "tox-uv", ] wandb = ["wandb"] +mlflow = ["mlflow>=3.0"] [project.urls] Homepage = "https://github.com/Red-Hat-AI-Innovation-Team/mini_trainer" diff --git a/src/mini_trainer/api_train.py b/src/mini_trainer/api_train.py index 87f49b1..78152f4 100644 --- a/src/mini_trainer/api_train.py +++ b/src/mini_trainer/api_train.py @@ -154,6 +154,16 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: if train_args.wandb_entity: command.append(f"--wandb-entity={train_args.wandb_entity}") + # mlflow-related arguments + if train_args.mlflow_tracking_uri: + command.append(f"--mlflow-tracking-uri={train_args.mlflow_tracking_uri}") + if train_args.mlflow_experiment_name: + command.append( + f"--mlflow-experiment-name={train_args.mlflow_experiment_name}" + ) + if train_args.mlflow_run_name: + command.append(f"--mlflow-run-name={train_args.mlflow_run_name}") + # validation-related arguments if train_args.validation_split > 0.0: command.append(f"--validation-split={train_args.validation_split}") diff --git a/src/mini_trainer/async_structured_logger.py b/src/mini_trainer/async_structured_logger.py index c33263b..a00878b 100644 --- a/src/mini_trainer/async_structured_logger.py +++ b/src/mini_trainer/async_structured_logger.py @@ -14,21 +14,28 @@ from tqdm import tqdm # Local imports -from mini_trainer import wandb_wrapper +from mini_trainer import wandb_wrapper, mlflow_wrapper from mini_trainer.wandb_wrapper import check_wandb_available - +from mini_trainer.mlflow_wrapper import check_mlflow_available class AsyncStructuredLogger: - def __init__(self, file_name="training_log.jsonl", use_wandb=False): + def __init__( + self, file_name="training_log.jsonl", use_wandb=False, use_mlflow=False + ): self.file_name = file_name - + # wandb init is a special case -- if it is requested but unavailable, # we should error out early 
if use_wandb: check_wandb_available("initialize wandb") self.use_wandb = use_wandb + # mlflow init - same pattern as wandb + if use_mlflow: + check_mlflow_available("initialize mlflow") + self.use_mlflow = use_mlflow + # Rich console for prettier output (force_terminal=True works with subprocess streaming) self.console = Console(force_terminal=True, force_interactive=False) @@ -67,12 +74,24 @@ async def log(self, data): data["timestamp"] = datetime.now().isoformat() self.logs.append(data) await self._write_logs_to_file(data) - - # log to wandb if enabled and wandb is initialized, but only log this on the MAIN rank + + # log to wandb/mlflow if enabled, but only log this on the MAIN rank + # Guard rank checks when the process group isn't initialized (single-process runs) + is_rank0 = not dist.is_initialized() or dist.get_rank() == 0 + # wandb already handles timestamps so no need to include - if self.use_wandb and dist.get_rank() == 0: + if self.use_wandb and is_rank0: wandb_data = {k: v for k, v in data.items() if k != "timestamp"} wandb_wrapper.log(wandb_data) + + # log to mlflow if enabled, only on MAIN rank + # Filter out step from data since it's passed as a separate argument + if self.use_mlflow and is_rank0: + step = data.get("step") + mlflow_data = { + k: v for k, v in data.items() if k not in ("timestamp", "step") + } + mlflow_wrapper.log(mlflow_data, step=step) except Exception as e: print(f"\033[1;38;2;0;255;255mError logging data: {e}\033[0m") @@ -80,10 +99,10 @@ async def _write_logs_to_file(self, data): """appends to the log instead of writing the whole log each time""" async with aiofiles.open(self.file_name, "a") as f: await f.write(json.dumps(data, indent=None) + "\n") - + def log_sync(self, data: dict): """Runs the log coroutine non-blocking and prints metrics with tqdm-styled progress bar. - + Args: data: Dictionary of metrics to log. Will automatically print a tqdm-formatted progress bar with ANSI colors if step and steps_per_epoch are present. 
@@ -96,36 +115,36 @@ def log_sync(self, data: dict): should_print = not dist.is_initialized() or dist.get_rank() == 0 if should_print: data_with_timestamp = {**data, "timestamp": datetime.now().isoformat()} - + # Print the JSON using Rich for syntax highlighting self.console.print_json(json.dumps(data_with_timestamp)) - + # Print tqdm-styled progress bar after JSON (prints as new line each time) # This works correctly with subprocess streaming - if 'step' in data and 'steps_per_epoch' in data and 'epoch' in data: + if "step" in data and "steps_per_epoch" in data and "epoch" in data: # Initialize tqdm on first call (lazy init to avoid early printing) if self.train_pbar is None: # Simple bar format with ANSI colors - we'll add epoch and metrics manually self.train_bar_format = ( - '{bar} ' - '\033[33m{percentage:3.0f}%\033[0m │ ' - '\033[37m{n}/{total}\033[0m' + "{bar} " + "\033[33m{percentage:3.0f}%\033[0m │ " + "\033[37m{n}/{total}\033[0m" ) self.train_pbar = tqdm( - total=data['steps_per_epoch'], + total=data["steps_per_epoch"], bar_format=self.train_bar_format, ncols=None, leave=False, position=0, file=sys.stdout, - ascii='━╺─', # custom characters matching Rich style + ascii="━╺─", # custom characters matching Rich style disable=True, # disable auto-display, we'll manually call display() ) # Reset tqdm if we're in a new epoch - current_step_in_epoch = (data['step'] - 1) % data['steps_per_epoch'] + 1 + current_step_in_epoch = (data["step"] - 1) % data["steps_per_epoch"] + 1 if current_step_in_epoch == 1: - self.train_pbar.reset(total=data['steps_per_epoch']) + self.train_pbar.reset(total=data["steps_per_epoch"]) # Update tqdm position self.train_pbar.n = current_step_in_epoch @@ -133,24 +152,24 @@ def log_sync(self, data: dict): # Manually format the complete progress line with metrics using format_meter bar_str = self.train_pbar.format_meter( n=current_step_in_epoch, - total=data['steps_per_epoch'], + total=data["steps_per_epoch"], elapsed=0, # we don't track 
# SPDX-License-Identifier: Apache-2.0

"""
Wrapper for optional mlflow imports that provides consistent error handling
across all processes when mlflow is not installed.

All public functions raise :class:`MLflowNotAvailableError` (an ``ImportError``
subclass) when mlflow is requested but not installed, so callers fail fast and
uniformly on every rank.
"""

import logging
import os
from typing import Any, Dict, Optional

# Try to import mlflow; the package is an optional dependency.
try:
    import mlflow

    MLFLOW_AVAILABLE = True
except ImportError:
    MLFLOW_AVAILABLE = False
    mlflow = None

logger = logging.getLogger(__name__)

# Store the active run ID so logging can resume the run if needed.
# This is needed because async logging may lose the thread-local run context.
_active_run_id: Optional[str] = None


class MLflowNotAvailableError(ImportError):
    """Raised when mlflow functions are called but mlflow is not installed."""

    pass


def check_mlflow_available(operation: str) -> None:
    """Check if mlflow is available, raise error if not.

    Args:
        operation: Human-readable description of the attempted operation,
            embedded in the error message for easier debugging.

    Raises:
        MLflowNotAvailableError: If mlflow is not installed.
    """
    if not MLFLOW_AVAILABLE:
        error_msg = (
            f"Attempted to {operation} but mlflow is not installed. "
            "Please install mlflow with: pip install mlflow"
        )
        logger.error(error_msg)
        raise MLflowNotAvailableError(error_msg)


def init(
    tracking_uri: Optional[str] = None,
    experiment_name: Optional[str] = None,
    run_name: Optional[str] = None,
    **kwargs,
) -> Any:
    """
    Initialize an mlflow run. Raises MLflowNotAvailableError if mlflow is not installed.

    Configuration follows a precedence hierarchy:
    1. Explicit kwargs (highest priority)
    2. Environment variables (MLFLOW_TRACKING_URI, MLFLOW_EXPERIMENT_NAME)
    3. MLflow defaults (lowest priority)

    Args:
        tracking_uri: MLflow tracking server URI (e.g., "http://localhost:5000").
            Falls back to MLFLOW_TRACKING_URI environment variable if not provided.
        experiment_name: Name of the experiment.
            Falls back to MLFLOW_EXPERIMENT_NAME environment variable if not provided.
        run_name: Name of the run
        **kwargs: Additional arguments to pass to mlflow.start_run

    Returns:
        mlflow.ActiveRun object if successful

    Raises:
        MLflowNotAvailableError: If mlflow is not installed
    """
    global _active_run_id
    check_mlflow_available("initialize mlflow")

    # Apply kwarg > env var precedence for tracking_uri
    effective_tracking_uri = tracking_uri or os.environ.get("MLFLOW_TRACKING_URI")
    if effective_tracking_uri:
        mlflow.set_tracking_uri(effective_tracking_uri)

    # Apply kwarg > env var precedence for experiment_name
    effective_experiment_name = experiment_name or os.environ.get(
        "MLFLOW_EXPERIMENT_NAME"
    )
    if effective_experiment_name:
        mlflow.set_experiment(effective_experiment_name)

    # Remove run_name from kwargs if present to avoid duplicate keyword argument
    # The explicit run_name parameter takes precedence
    kwargs.pop("run_name", None)

    # Reuse existing active run if one exists, otherwise start a new one
    active_run = mlflow.active_run()
    if active_run is not None:
        run = active_run
    else:
        run = mlflow.start_run(run_name=run_name, **kwargs)

    # BUGFIX: record the run id in BOTH branches. Previously a reused active
    # run left _active_run_id unset, so _ensure_run_for_logging() could not
    # resume the run from async/thread contexts that lost the run context.
    _active_run_id = run.info.run_id
    return run


def get_active_run_id() -> Optional[str]:
    """Get the active run ID that was started by init()."""
    return _active_run_id


def _ensure_run_for_logging() -> None:
    """Ensure there's an active MLflow run for logging.

    This helper handles async contexts where thread-local run context may be lost.
    If no active run exists but we have a stored run ID, it resumes that run.
    """
    active_run = mlflow.active_run()
    if not active_run and _active_run_id:
        # No active run in this thread but we have a stored run ID - resume it
        # This can happen in async contexts where thread-local context is lost
        # Note: We don't use context manager here because it would end the run on exit
        mlflow.start_run(run_id=_active_run_id)


def log_params(params: Dict[str, Any]) -> None:
    """
    Log parameters to mlflow. Raises MLflowNotAvailableError if mlflow is not installed.

    Args:
        params: Dictionary of parameters to log

    Raises:
        MLflowNotAvailableError: If mlflow is not installed
    """
    check_mlflow_available("log params to mlflow")
    # MLflow params must be strings
    str_params = {k: str(v) for k, v in params.items()}

    _ensure_run_for_logging()
    mlflow.log_params(str_params)


def log(data: Dict[str, Any], step: Optional[int] = None) -> None:
    """
    Log metrics to mlflow. Raises MLflowNotAvailableError if mlflow is not installed.

    Args:
        data: Dictionary of data to log (non-numeric values will be skipped)
        step: Optional step number for the metrics

    Raises:
        MLflowNotAvailableError: If mlflow is not installed
    """
    check_mlflow_available("log to mlflow")
    # Filter to only numeric values for metrics; MLflow metrics must be floats.
    metrics = {}
    for k, v in data.items():
        try:
            metrics[k] = float(v)
        except (ValueError, TypeError):
            pass  # Skip non-numeric values
    if metrics:
        _ensure_run_for_logging()
        mlflow.log_metrics(metrics, step=step)


def finish() -> None:
    """
    End the mlflow run. Raises MLflowNotAvailableError if mlflow is not installed.

    Raises:
        MLflowNotAvailableError: If mlflow is not installed
    """
    global _active_run_id
    check_mlflow_available("finish mlflow run")
    mlflow.end_run()
    # Clear the stored run id so a later init() starts fresh.
    _active_run_id = None
Initialize use_wandb variable + # Initialize logging flags use_wandb = wandb_project is not None + use_mlflow = any([mlflow_tracking_uri, mlflow_experiment_name, mlflow_run_name]) # Log parameters only on rank 0 local_rank = int(os.getenv("LOCAL_RANK", 0)) @@ -1265,6 +1277,9 @@ def main( "wandb_project": wandb_project, "wandb_run_name": wandb_run_name, "wandb_entity": wandb_entity, + "mlflow_tracking_uri": mlflow_tracking_uri, + "mlflow_experiment_name": mlflow_experiment_name, + "mlflow_run_name": mlflow_run_name, "LOCAL_RANK": local_rank, "GLOBAL_RANK": global_rank, "NODE_RANK": node_rank, @@ -1288,6 +1303,19 @@ def main( ) log_rank_0(f"Initialized wandb project: {wandb_project}") + # Initialize mlflow with the same params config + # Only init on global rank 0 to avoid multiple runs in multi-node setups + if use_mlflow and global_rank == 0: + mlflow_wrapper.init( + tracking_uri=mlflow_tracking_uri, + experiment_name=mlflow_experiment_name, + run_name=mlflow_run_name + or wandb_run_name, # fallback to wandb_run_name + ) + # Log hyperparameters + mlflow_wrapper.log_params(params) + log_rank_0(f"Initialized mlflow with tracking URI: {mlflow_tracking_uri}") + params_path = output_path / "training_params.json" with open(params_path, "w") as f: json.dump(params, f, indent=4) @@ -1398,6 +1426,7 @@ def main( save_best_val_loss=save_best_val_loss, val_loss_improvement_threshold=val_loss_improvement_threshold, use_wandb=use_wandb, + use_mlflow=use_mlflow, val_data_loader=val_data_loader, validation_frequency=validation_frequency, ) @@ -1405,6 +1434,9 @@ def main( # once done, tear down distributed environment if use_wandb: wandb_wrapper.finish() + # Only finish mlflow on global rank 0 (where we started it) + if use_mlflow and torch.distributed.get_rank() == 0: + mlflow_wrapper.finish() destroy_distributed_environment() diff --git a/src/mini_trainer/training_types.py b/src/mini_trainer/training_types.py index 5c9fcf9..6615a18 100644 --- a/src/mini_trainer/training_types.py 
+++ b/src/mini_trainer/training_types.py @@ -225,6 +225,17 @@ class TrainingArgs: default=None, metadata={"help": "Weights & Biases entity/team name."} ) + # MLflow integration + mlflow_tracking_uri: Optional[str] = field( + default=None, metadata={"help": "MLflow tracking server URI."} + ) + mlflow_experiment_name: Optional[str] = field( + default=None, metadata={"help": "MLflow experiment name."} + ) + mlflow_run_name: Optional[str] = field( + default=None, metadata={"help": "MLflow run name."} + ) + # validation validation_split: float = field( default=0.0,