-
Notifications
You must be signed in to change notification settings - Fork 17
Add MLflow logging support #66
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
a997584
642acc5
7830772
e86bbba
492c88b
a6fb893
60f4d9a
b9a5992
91ddadc
abc67b0
d72b444
db2c492
33fa513
0eb1296
bd799bc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,173 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| """ | ||
| Wrapper for optional mlflow imports that provides consistent error handling | ||
| across all processes when mlflow is not installed. | ||
| """ | ||
|
|
||
| import logging | ||
| import os | ||
| from typing import Any, Dict, Optional | ||
|
|
||
# Optional dependency: mlflow may be absent.  Callers consult MLFLOW_AVAILABLE
# (directly or via check_mlflow_available) before using the module.
try:
    import mlflow
except ImportError:
    mlflow = None
    MLFLOW_AVAILABLE = False
else:
    MLFLOW_AVAILABLE = True

logger = logging.getLogger(__name__)

# Run ID of the run opened by init().  Kept at module scope because async
# logging can lose mlflow's thread-local run context; this ID lets the
# logging helpers reattach to the original run.
_active_run_id: Optional[str] = None
|
|
||
|
|
||
class MLflowNotAvailableError(ImportError):
    """Error raised when an mlflow operation is requested but mlflow is not installed.

    Subclasses ``ImportError`` so callers may catch either this specific
    error or the broader import-failure category.
    """
|
|
||
|
|
||
def check_mlflow_available(operation: str) -> None:
    """Verify that the optional mlflow dependency is importable.

    Args:
        operation: Human-readable description of the attempted action,
            interpolated into the error message for context.

    Raises:
        MLflowNotAvailableError: If mlflow could not be imported.
    """
    if not MLFLOW_AVAILABLE:
        # Point users at this project's optional extra rather than a bare
        # "pip install mlflow", so the installed mlflow version matches the
        # constraint pinned by the distribution's [mlflow] extra.
        error_msg = (
            f"Attempted to {operation} but mlflow is not installed. "
            "Please install mlflow with: pip install 'rhai-innovation-mini-trainer[mlflow]'"
        )
        logger.error(error_msg)
        raise MLflowNotAvailableError(error_msg)
|
Comment on lines
+28
to
+42
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chain🏁 Script executed: find . -name "pyproject.toml" -type f | head -5Repository: Red-Hat-AI-Innovation-Team/mini_trainer Length of output: 99 🏁 Script executed: cat pyproject.tomlRepository: Red-Hat-AI-Innovation-Team/mini_trainer Length of output: 2785 🏁 Script executed: sed -n '28,42p' src/mini_trainer/mlflow_wrapper.pyRepository: Red-Hat-AI-Innovation-Team/mini_trainer Length of output: 618 Align MLflow install hint with distribution name. Line 39 references 🔧 Suggested update- "Please install mlflow with: pip install 'mini-trainer[mlflow]'"
+ "Please install mlflow with: pip install 'rhai-innovation-mini-trainer[mlflow]'"🤖 Prompt for AI Agents |
||
|
|
||
|
|
||
def init(
    tracking_uri: Optional[str] = None,
    experiment_name: Optional[str] = None,
    run_name: Optional[str] = None,
    **kwargs,
) -> Any:
    """Start (or adopt) an mlflow run and remember its run ID.

    Configuration precedence, highest first:
      1. Explicit keyword arguments.
      2. Environment variables (MLFLOW_TRACKING_URI, MLFLOW_EXPERIMENT_NAME).
      3. MLflow's own defaults.

    Args:
        tracking_uri: Tracking server URI (e.g. "http://localhost:5000").
            Falls back to the MLFLOW_TRACKING_URI environment variable.
        experiment_name: Experiment to log under.  Falls back to the
            MLFLOW_EXPERIMENT_NAME environment variable.
        run_name: Display name for the run.
        **kwargs: Extra arguments forwarded to ``mlflow.start_run``.

    Returns:
        The ``mlflow.ActiveRun`` object for the adopted or newly started run.

    Raises:
        MLflowNotAvailableError: If mlflow is not installed.
    """
    global _active_run_id
    check_mlflow_available("initialize mlflow")

    # Explicit kwarg beats environment variable for the tracking URI.
    uri = tracking_uri or os.environ.get("MLFLOW_TRACKING_URI")
    if uri:
        mlflow.set_tracking_uri(uri)

    # Explicit kwarg beats environment variable for the experiment name.
    experiment = experiment_name or os.environ.get("MLFLOW_EXPERIMENT_NAME")
    if experiment:
        mlflow.set_experiment(experiment)

    # Discard any run_name smuggled in via kwargs so the explicit
    # parameter cannot collide with a duplicate keyword argument.
    kwargs.pop("run_name", None)

    # Adopt a pre-existing active run when present; otherwise open a new one.
    run = mlflow.active_run()
    if run is None:
        run = mlflow.start_run(run_name=run_name, **kwargs)
    # Remember the run ID so async code paths can reattach to this run.
    _active_run_id = run.info.run_id
    return run
RobotSail marked this conversation as resolved.
Show resolved
Hide resolved
RobotSail marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
def get_active_run_id() -> Optional[str]:
    """Return the run ID recorded by ``init()``, or ``None`` if no run was started."""
    return _active_run_id
|
Comment on lines
+34
to
+104
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chain🏁 Script executed: find . -name "mlflow_wrapper.py" -type fRepository: Red-Hat-AI-Innovation-Team/mini_trainer Length of output: 119 🏁 Script executed: cat -n src/mini_trainer/mlflow_wrapper.py | head -150Repository: Red-Hat-AI-Innovation-Team/mini_trainer Length of output: 5266 🌐 Web query:
💡 Result:
Sources: Citations:
Add The Pass Fix# In log_params() - line 96
if _active_run_id:
with mlflow.start_run(run_id=_active_run_id, nested=True):
mlflow.log_params(str_params)
# In log() - line 125
if _active_run_id:
with mlflow.start_run(run_id=_active_run_id, nested=True):
mlflow.log_metrics(metrics, step=step)🤖 Prompt for AI Agents |
||
|
|
||
|
|
||
def _ensure_run_for_logging() -> None:
    """Reattach to the stored run when the thread-local mlflow context is missing.

    Async execution can strip mlflow's thread-local active-run context.  When
    no run is active in this thread but ``init()`` recorded a run ID, resume
    that run.  A ``with`` block is deliberately avoided here: exiting the
    context manager would end the run prematurely.
    """
    if not mlflow.active_run() and _active_run_id:
        mlflow.start_run(run_id=_active_run_id)
|
|
||
|
|
||
def log_params(params: Dict[str, Any]) -> None:
    """Record run parameters with mlflow.

    Every value is coerced to ``str`` because mlflow stores parameters
    as strings.

    Args:
        params: Mapping of parameter names to values.

    Raises:
        MLflowNotAvailableError: If mlflow is not installed.
    """
    check_mlflow_available("log params to mlflow")
    stringified = {name: str(value) for name, value in params.items()}

    _ensure_run_for_logging()
    mlflow.log_params(stringified)
|
|
||
|
|
||
def log(data: Dict[str, Any], step: Optional[int] = None) -> None:
    """Record numeric metrics with mlflow, dropping non-numeric entries.

    Args:
        data: Mapping of metric names to values; entries that cannot be
            coerced to ``float`` are silently skipped.
        step: Optional step number to associate with the metrics.

    Raises:
        MLflowNotAvailableError: If mlflow is not installed.
    """
    check_mlflow_available("log to mlflow")
    # Keep only values that coerce cleanly to float; mlflow metrics are numeric.
    numeric: Dict[str, float] = {}
    for name, value in data.items():
        try:
            numeric[name] = float(value)
        except (ValueError, TypeError):
            continue  # non-numeric entries are intentionally ignored
    if numeric:
        _ensure_run_for_logging()
        mlflow.log_metrics(numeric, step=step)
|
|
||
|
|
||
def finish() -> None:
    """End the active mlflow run and clear the stored run ID.

    Raises:
        MLflowNotAvailableError: If mlflow is not installed.
    """
    global _active_run_id
    check_mlflow_available("finish mlflow run")
    mlflow.end_run()
    # Clear the resume marker so later logging cannot reattach to a closed run.
    _active_run_id = None
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧩 Analysis chain
🏁 Script executed:
Repository: Red-Hat-AI-Innovation-Team/mini_trainer
Length of output: 1447
🏁 Script executed:
Repository: Red-Hat-AI-Innovation-Team/mini_trainer
Length of output: 2126
🏁 Script executed:
Repository: Red-Hat-AI-Innovation-Team/mini_trainer
Length of output: 628
🏁 Script executed:
Repository: Red-Hat-AI-Innovation-Team/mini_trainer
Length of output: 988
🏁 Script executed:
Repository: Red-Hat-AI-Innovation-Team/mini_trainer
Length of output: 253
🏁 Script executed:
Repository: Red-Hat-AI-Innovation-Team/mini_trainer
Length of output: 6208
🏁 Script executed:
Repository: Red-Hat-AI-Innovation-Team/mini_trainer
Length of output: 5599
MLflow experiment and run name args should be independent of tracking URI.
Lines 160–165 only append
--mlflow-experiment-nameand--mlflow-run-namewhenmlflow_tracking_uriis set. However, the training code enables MLflow if any of the three MLflow args are provided (use_mlflow = any([mlflow_tracking_uri, mlflow_experiment_name, mlflow_run_name])), and MLflow supports using the default local backend without an explicit tracking URI. Move these two conditionals outside themlflow_tracking_uriblock so they append independently.🛠️ Suggested fix
🤖 Prompt for AI Agents