|
| 1 | +import re |
| 2 | +import subprocess |
1 | 3 | from abc import ABC, abstractmethod |
2 | 4 | from collections.abc import Mapping |
| 5 | +from functools import partial |
| 6 | +from pathlib import Path |
3 | 7 | from typing import Any |
| 8 | +from urllib.parse import urlsplit |
4 | 9 |
|
5 | 10 | from omegaconf import DictConfig |
6 | 11 |
|
@@ -32,3 +37,127 @@ def log_tags(self, tag_dict: Mapping[str, Any]) -> None: |
32 | 37 | def close(self) -> None: |
33 | 38 | """Close the logger.""" |
34 | 39 | ... |
| 40 | + |
def _sanitize_remote(self, remote: str) -> str:
    """Sanitize git remote URL to remove potential credentials.

    Two remote syntaxes are recognised:

    * URL form, ``scheme://[user[:password]@]host[:port]/path``: rebuilt
      from scheme, host, port and path only, so userinfo, query string and
      fragment are dropped.
    * SCP-like form, ``user[:password]@host:path``: everything up to and
      including the ``@`` is dropped.

    Any other value (empty string, local path, ``host:path`` without a
    user, a placeholder such as ``"unknown"``) is returned unchanged.

    Args:
        remote: The git remote URL

    Returns:
        Sanitized remote URL without credentials
    """
    if not remote:
        return remote

    # Only treat the value as a URL when it literally starts with
    # "<scheme>://". urlsplit() on its own also reports a "scheme" for
    # scheme-less inputs such as "host:path", "C:/repo" or
    # "user:secret@host:path", which would then be rebuilt into a bogus
    # URL (and, in the last case, keep the password in the output).
    if re.match(r"^[A-Za-z][A-Za-z0-9+.\-]*://", remote):
        try:
            parts = urlsplit(remote)
            host = parts.hostname or ""
            port = parts.port  # raises ValueError for a malformed port
        except ValueError:
            # Malformed URL: fall through to the SCP-like handling below,
            # which still strips anything in front of an "@".
            pass
        else:
            if parts.scheme:
                if ":" in host:
                    # .hostname drops the brackets of an IPv6 literal;
                    # put them back so the host/port boundary stays valid.
                    host = f"[{host}]"
                port_suffix = f":{port}" if port else ""
                return f"{parts.scheme}://{host}{port_suffix}{parts.path or ''}"

    # SCP-like: user@host:path
    scp_match = re.match(r"^[^@]+@([^:]+):(.*)$", remote)
    if scp_match:
        host, path = scp_match.groups()
        return f"{host}:{path}"

    # Otherwise return as-is
    return remote
| 73 | + |
| 74 | + def _find_git_root(self, start: Path) -> Path | None: |
| 75 | + """Find the git repository root from a starting path. |
| 76 | +
|
| 77 | + Args: |
| 78 | + start: Starting path to search from |
| 79 | +
|
| 80 | + Returns: |
| 81 | + Path to git repository root, or None if not found |
| 82 | + """ |
| 83 | + try: |
| 84 | + r = subprocess.run( |
| 85 | + ["git", "-C", str(start), "rev-parse", "--show-toplevel"], |
| 86 | + capture_output=True, |
| 87 | + text=True, |
| 88 | + timeout=1, |
| 89 | + ) |
| 90 | + if r.returncode == 0: |
| 91 | + return Path(r.stdout.strip()) |
| 92 | + except Exception: |
| 93 | + pass |
| 94 | + for parent in [start.resolve(), *start.resolve().parents]: |
| 95 | + if (parent / ".git").exists(): |
| 96 | + return parent |
| 97 | + return None |
| 98 | + |
def _get_git_info(self, repo_path: Path) -> dict[str, str]:
    """Get git repository information.

    Runs four git commands against ``repo_path`` (each with a 2 second
    timeout) and reports their results as strings. A field whose command
    fails or prints nothing is reported as ``"unknown"``.

    Args:
        repo_path: Path to the git repository

    Returns:
        Dictionary with keys ``commit`` (first 8 characters of the hash),
        ``commit_full``, ``dirty`` (``"True"``, ``"False"`` or
        ``"unknown"``), ``branch`` and ``remote`` (credentials stripped).
        Empty dictionary if git could not be executed at all.
    """
    # Common subprocess settings and the shared "git -C <repo>" prefix.
    run = partial(subprocess.run, capture_output=True, text=True, timeout=2.0)
    git = ["git", "-C", str(repo_path)]

    def query(*args: str) -> str:
        """Return the stripped stdout of a git command, or "unknown"."""
        result = run([*git, *args])
        output = result.stdout.strip() if result.returncode == 0 else ""
        return output or "unknown"

    try:
        # Get commit hash
        commit_full = query("rev-parse", "HEAD")
        commit_short = commit_full[:8] if commit_full != "unknown" else "unknown"

        # Check for uncommitted changes to tracked files. With --quiet,
        # diff-index exits 0 for "no differences" and 1 for "differences";
        # any other status is an error (e.g. no commits yet, not a
        # repository) and must not be reported as a dirty tree.
        status = run([*git, "diff-index", "--quiet", "HEAD", "--"]).returncode
        dirty = {0: "False", 1: "True"}.get(status, "unknown")

        # Get current branch name. On a detached HEAD the command succeeds
        # with empty output, which query() maps to "unknown".
        branch = query("branch", "--show-current")

        # Get remote URL and sanitize it
        remote_url = self._sanitize_remote(query("remote", "get-url", "origin"))

        return {
            "commit": commit_short,
            "commit_full": commit_full,
            "dirty": dirty,
            "branch": branch,
            "remote": remote_url,
        }
    except (OSError, subprocess.SubprocessError, ValueError):
        # git is not installed / not executable, a command timed out, or
        # its output could not be decoded: git metadata is best-effort.
        return {}
| 140 | + |
def log_git_info(self) -> None:
    """Record git metadata as tags so a run can be reproduced.

    Two locations are inspected, in this order: the current working
    directory (tag prefix ``git.main.``) and the directory containing this
    module (tag prefix ``git.simplexity.``). For each location whose
    repository root is found, every entry returned by ``_get_git_info`` is
    added under that prefix. ``log_tags`` is called once with the combined
    tags, and not at all when nothing was collected.
    """
    # (tag prefix, callable producing the directory to search from).
    # Callables keep each location lazily evaluated, in order.
    origins = (
        ("git.main.", Path.cwd),
        ("git.simplexity.", lambda: Path(__file__).resolve().parent),
    )

    collected = {}
    for prefix, locate in origins:
        repo_root = self._find_git_root(locate())
        if not repo_root:
            continue
        info = self._get_git_info(repo_root)
        collected.update({prefix + key: value for key, value in info.items()})

    if collected:
        self.log_tags(collected)
0 commit comments