Skip to content

Commit bef8a1f

Browse files
committed
Add tool to keep test-requirements in sync with pre-commit
1 parent c632be5 commit bef8a1f

File tree

2 files changed

+109
-0
lines changed

2 files changed

+109
-0
lines changed

.pre-commit-config.yaml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,13 @@ repos:
5858
pass_filenames: false
5959
additional_dependencies: ["astor", "attrs", "black", "ruff"]
6060
files: ^src\/trio\/_core\/(_run|(_i(o_(common|epoll|kqueue|windows)|nstrumentation)))\.py$
61+
- id: sync-test-requirements
62+
name: synchronize test requirements
63+
language: python
64+
entry: python src/trio/_tools/sync_requirements.py
65+
pass_filenames: false
66+
additional_dependencies: ["pyyaml"]
67+
files: ^(test-requirements\.txt)|(\.pre-commit-config\.yaml)$
6168
- repo: https://github.com/astral-sh/uv-pre-commit
6269
rev: 0.6.6
6370
hooks:
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
#!/usr/bin/env python3
2+
3+
"""Sync Requirements - Automatically upgrade test requirements pinned
4+
versions from pre-commit config file."""
5+
6+
from __future__ import annotations
7+
8+
import sys
9+
from pathlib import Path
10+
from typing import TYPE_CHECKING
11+
12+
from yaml import load as load_yaml
13+
14+
if TYPE_CHECKING:
15+
from collections.abc import Generator
16+
17+
from yaml import CLoader as _CLoader, Loader as _Loader
18+
19+
Loader: type[_CLoader | _Loader]
20+
21+
try:
22+
from yaml import CLoader as Loader
23+
except ImportError:
24+
from yaml import Loader
25+
26+
27+
def yield_pre_commit_version_data(
28+
pre_commit: Path,
29+
) -> Generator[tuple[str, str], None, None]:
30+
"""Yield (name, rev) tuples from pre-commit config file."""
31+
pre_commit_config = load_yaml(pre_commit.read_text(encoding="utf-8"), Loader)
32+
for repo in pre_commit_config["repos"]:
33+
if "repo" not in repo or "rev" not in repo:
34+
continue
35+
url = repo["repo"]
36+
name = url.rsplit("/", 1)[-1]
37+
rev = repo["rev"].removeprefix("v")
38+
yield name, rev
39+
40+
41+
def update_requirements(
42+
requirements: Path,
43+
version_data: dict[str, str],
44+
) -> bool:
45+
"""Return if updated requirements file.
46+
47+
Update requirements file to match versions in version_data."""
48+
changed = False
49+
old_lines = requirements.read_text(encoding="utf-8").splitlines(True)
50+
51+
with requirements.open("w", encoding="utf-8") as file:
52+
for line in old_lines:
53+
# If comment or not version mark line, ignore.
54+
if line.startswith("#") or "==" not in line:
55+
file.write(line)
56+
continue
57+
name, rest = line.split("==", 1)
58+
# Maintain extra markers if they exist
59+
old_version = rest.strip()
60+
extra = "\n"
61+
if " " in rest:
62+
old_version, extra = rest.split(" ", 1)
63+
extra = " " + extra
64+
version = version_data.get(name)
65+
# If does not exist, skip
66+
if version is None:
67+
file.write(line)
68+
continue
69+
# Otherwise might have changed
70+
new_line = f"{name}=={version}{extra}"
71+
if new_line != line:
72+
if not changed:
73+
changed = True
74+
print("Changed test requirements version to match pre-commit")
75+
print(f"{name}=={old_version} -> {name}=={version}")
76+
file.write(new_line)
77+
return changed
78+
79+
80+
def main() -> int:
81+
"""Run program."""
82+
83+
source_root = Path.cwd().absolute()
84+
while not (source_root / "LICENSE").exists():
85+
source_root = source_root.parent
86+
# Double-check we found the right directory
87+
assert (source_root / "LICENSE").exists()
88+
pre_commit = source_root / ".pre-commit-config.yaml"
89+
test_requirements = source_root / "test-requirements.txt"
90+
91+
# Get tool versions from pre-commit
92+
# Get correct names
93+
pre_commit_versions = {
94+
name.removesuffix("-mirror").removesuffix("-pre-commit"): version
95+
for name, version in yield_pre_commit_version_data(pre_commit)
96+
}
97+
changed = update_requirements(test_requirements, pre_commit_versions)
98+
return int(changed)
99+
100+
101+
if __name__ == "__main__":
102+
sys.exit(main())

0 commit comments

Comments
 (0)