Skip to content
133 changes: 133 additions & 0 deletions simplexity/logging/mlflow_logger.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import os
import subprocess
import tempfile
import time
from collections.abc import Mapping
from pathlib import Path
from typing import Any

import dotenv
Expand Down Expand Up @@ -34,6 +36,9 @@ def __init__(
run = self._client.create_run(experiment_id=experiment_id, run_name=run_name)
self._run_id = run.info.run_id

# Automatically log git information for reproducibility
self._log_git_info()

def log_config(self, config: DictConfig, resolve: bool = False) -> None:
"""Log config to MLflow."""
with tempfile.TemporaryDirectory() as temp_dir:
Expand Down Expand Up @@ -94,3 +99,131 @@ def close(self):
def _log_batch(self, **kwargs: Any) -> None:
    """Forward a batch of data to the MLflow client for this logger's run.

    Args:
        **kwargs: Keyword arguments passed straight through to the client's
            ``log_batch`` call. ``synchronous`` is always supplied as ``False``
            by this method, so it must not be passed by the caller.
    """
    client = self._client
    client.log_batch(self._run_id, synchronous=False, **kwargs)

def _get_git_info(self, repo_path: Path) -> dict[str, str]:
    """Collect git metadata for the repository that contains ``repo_path``.

    Collection is best-effort and never raises: reproducibility tags must not
    be able to break logger construction.

    Args:
        repo_path: Directory in which the git commands are executed.

    Returns:
        A dictionary with the keys ``commit`` (first 8 characters of the
        hash), ``commit_full``, ``dirty`` (``"True"`` or ``"False"``),
        ``branch`` and ``remote``. A field whose git command exits with a
        non-zero status is reported as ``"unknown"`` (``dirty`` as
        ``"False"``). An empty dictionary is returned when git cannot be
        executed, a command times out, or ``repo_path`` is not inside a git
        repository.
    """
    unknown = "unknown"
    timeout_seconds = 2

    def run_git(*args: str, default: str) -> str:
        """Return the stripped stdout of ``git <args>``, or ``default`` on a non-zero exit."""
        completed = subprocess.run(
            ["git", *args],
            cwd=repo_path,
            capture_output=True,
            text=True,
            timeout=timeout_seconds,
        )
        return completed.stdout.strip() if completed.returncode == 0 else default

    try:
        # Outside a repository every git command fails; bail out early instead
        # of reporting a dictionary made up entirely of "unknown" values.
        if not run_git("rev-parse", "--git-dir", default=""):
            return {}

        commit_full = run_git("rev-parse", "HEAD", default=unknown)
        # Any porcelain output means there are uncommitted or untracked changes.
        is_dirty = bool(run_git("status", "--porcelain", default=""))
        branch = run_git("rev-parse", "--abbrev-ref", "HEAD", default=unknown)
        remote_url = run_git("config", "--get", "remote.origin.url", default=unknown)
    except (OSError, subprocess.SubprocessError, ValueError):
        # OSError: git executable missing or repo_path unusable as cwd.
        # SubprocessError: a command exceeded the timeout.
        # ValueError: command output could not be decoded as text.
        return {}

    commit_short = commit_full[:8] if commit_full != unknown else unknown
    return {
        "commit": commit_short,
        "commit_full": commit_full,
        "dirty": str(is_dirty),
        "branch": branch,
        "remote": remote_url,
    }

def _log_git_info(self) -> None:
    """Tag the run with git metadata so that it can be reproduced later.

    Two locations are inspected via ``_get_git_info``: the current working
    directory (tags prefixed with ``git.main.``) and the directory that
    contains the ``simplexity`` package (tags prefixed with
    ``git.simplexity.``). Nothing is logged when neither yields any
    information.
    """
    main_info = self._get_git_info(Path.cwd())
    git_tags = {f"git.main.{name}": value for name, value in main_info.items()} if main_info else {}

    try:
        import simplexity

        package_root = None

        module_file = getattr(simplexity, "__file__", None)
        if module_file:
            # Regular package: <root>/simplexity/__init__.py -> <root>
            package_root = Path(module_file).parent.parent
        else:
            # Namespace package: take the first non-empty search path entry.
            search_paths = getattr(simplexity, "__path__", None)
            if search_paths:
                first_entry = next((entry for entry in search_paths if entry), None)
                if first_entry:
                    package_root = Path(first_entry).parent

        if package_root is None:
            # Last resort: ask the import system where the module came from.
            import importlib.util

            spec = importlib.util.find_spec("simplexity")
            if spec and spec.origin:
                package_root = Path(spec.origin).parent.parent

        if package_root is not None and package_root.exists():
            library_info = self._get_git_info(package_root)
            if library_info:
                git_tags.update({f"git.simplexity.{name}": value for name, value in library_info.items()})
    except (ImportError, AttributeError, TypeError):
        # Locating the library is optional; keep whatever was collected so far.
        pass

    if git_tags:
        self.log_tags(git_tags)

def log_storage_info(self, persister: Any) -> None:
    """Tag the run with the location the model persister writes to.

    The persister is classified by the attributes it exposes: ``bucket`` and
    ``prefix`` together mean S3, ``directory`` means local storage, and
    anything else is tagged as ``unknown``.

    Args:
        persister: Model persister object (S3Persister, LocalPersister, etc.)
    """
    looks_like_s3 = hasattr(persister, "bucket") and hasattr(persister, "prefix")

    if looks_like_s3:
        bucket, prefix = persister.bucket, persister.prefix
        storage_tags = {
            "storage.type": "s3",
            "storage.location": f"s3://{bucket}/{prefix}",
            "storage.bucket": bucket,
            "storage.prefix": prefix,
        }
    elif hasattr(persister, "directory"):
        absolute_dir = Path(persister.directory).absolute()
        storage_tags = {
            "storage.type": "local",
            "storage.location": str(absolute_dir),
        }
    else:
        storage_tags = {"storage.type": "unknown"}

    # Every branch sets at least "storage.type", so there is always something to log.
    self.log_tags(storage_tags)
Loading