Skip to content

Commit 418189c

Browse files
authored
Feature/mlflow git tracking (#65)
1 parent 26067b5 commit 418189c

File tree

1 file changed

+129
-0
lines changed

1 file changed

+129
-0
lines changed

simplexity/logging/logger.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
1+
import re
2+
import subprocess
13
from abc import ABC, abstractmethod
24
from collections.abc import Mapping
5+
from functools import partial
6+
from pathlib import Path
37
from typing import Any
8+
from urllib.parse import urlsplit
49

510
from omegaconf import DictConfig
611

@@ -32,3 +37,127 @@ def log_tags(self, tag_dict: Mapping[str, Any]) -> None:
3237
def close(self) -> None:
3338
"""Close the logger."""
3439
...
40+
41+
def _sanitize_remote(self, remote: str) -> str:
42+
"""Sanitize git remote URL to remove potential credentials.
43+
44+
Args:
45+
remote: The git remote URL
46+
47+
Returns:
48+
Sanitized remote URL without credentials
49+
"""
50+
if not remote:
51+
return remote
52+
53+
# Try URL-like first: http(s)://..., ssh://..., git+https://...
54+
try:
55+
parts = urlsplit(remote)
56+
if parts.scheme:
57+
# rebuild without username/password, query, fragment
58+
host = parts.hostname or ""
59+
port = f":{parts.port}" if parts.port else ""
60+
path = parts.path or ""
61+
return f"{parts.scheme}://{host}{port}{path}"
62+
except Exception:
63+
pass
64+
65+
# SCP-like: user@host:path
66+
m = re.match(r"^[^@]+@([^:]+):(.*)$", remote)
67+
if m:
68+
host, path = m.groups()
69+
return f"{host}:{path}"
70+
71+
# Otherwise return as-is
72+
return remote
73+
74+
def _find_git_root(self, start: Path) -> Path | None:
75+
"""Find the git repository root from a starting path.
76+
77+
Args:
78+
start: Starting path to search from
79+
80+
Returns:
81+
Path to git repository root, or None if not found
82+
"""
83+
try:
84+
r = subprocess.run(
85+
["git", "-C", str(start), "rev-parse", "--show-toplevel"],
86+
capture_output=True,
87+
text=True,
88+
timeout=1,
89+
)
90+
if r.returncode == 0:
91+
return Path(r.stdout.strip())
92+
except Exception:
93+
pass
94+
for parent in [start.resolve(), *start.resolve().parents]:
95+
if (parent / ".git").exists():
96+
return parent
97+
return None
98+
99+
def _get_git_info(self, repo_path: Path) -> dict[str, str]:
100+
"""Get git repository information.
101+
102+
Args:
103+
repo_path: Path to the git repository
104+
105+
Returns:
106+
Dictionary with git information (commit, branch, dirty state, remote)
107+
"""
108+
try:
109+
# Create a partial function with common arguments
110+
run = partial(subprocess.run, capture_output=True, text=True, timeout=2.0)
111+
112+
# Get commit hash
113+
result = run(["git", "-C", str(repo_path), "rev-parse", "HEAD"])
114+
commit_full = result.stdout.strip() if result.returncode == 0 else "unknown"
115+
commit_short = commit_full[:8] if commit_full != "unknown" else "unknown"
116+
117+
# Check if working directory is dirty (has uncommitted changes)
118+
result = run(["git", "-C", str(repo_path), "diff-index", "--quiet", "HEAD", "--"])
119+
is_dirty = result.returncode != 0
120+
121+
# Get current branch name
122+
result = run(["git", "-C", str(repo_path), "branch", "--show-current"])
123+
branch = result.stdout.strip() if result.returncode == 0 else "unknown"
124+
125+
# Get remote URL and sanitize it
126+
result = run(["git", "-C", str(repo_path), "remote", "get-url", "origin"])
127+
remote_url = result.stdout.strip() if result.returncode == 0 else "unknown"
128+
remote_url = self._sanitize_remote(remote_url)
129+
130+
return {
131+
"commit": commit_short,
132+
"commit_full": commit_full,
133+
"dirty": str(is_dirty),
134+
"branch": branch,
135+
"remote": remote_url,
136+
}
137+
except (subprocess.TimeoutExpired, FileNotFoundError, Exception):
138+
# Return empty dict if git is not available or repo is not a git repo
139+
return {}
140+
141+
def log_git_info(self) -> None:
    """Log git information for reproducibility.

    Logs git information for both the main repository (located from the
    current working directory) and the simplexity library repository
    (located from the directory containing this file). Tags are named
    ``git.main.<key>`` and ``git.simplexity.<key>``. Nothing is logged when
    no git information could be collected.
    """
    sources = (
        ("main", Path.cwd()),
        ("simplexity", Path(__file__).resolve().parent),
    )

    # Both starting points frequently resolve to the same repository (e.g.
    # when running from a checkout of the library); query git only once per
    # distinct root so both tag groups describe the same snapshot.
    info_by_root: dict[Path, dict[str, str]] = {}
    tags: dict[str, str] = {}

    for label, start in sources:
        root = self._find_git_root(start)
        if root is None:
            continue
        if root not in info_by_root:
            info_by_root[root] = self._get_git_info(root)
        for key, value in info_by_root[root].items():
            tags[f"git.{label}.{key}"] = value

    if tags:
        self.log_tags(tags)

0 commit comments

Comments
 (0)