Skip to content

Commit 823db82

Browse files
Merge pull request #10 from RadarML/dev/upgrade
CLI tools for managing results
2 parents 1bd63a1 + 1c2a397 commit 823db82

File tree

8 files changed

+281
-2
lines changed

8 files changed

+281
-2
lines changed

docs/cli.md

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,25 @@
2828
separate_signature: false
2929
show_root_heading: false
3030
show_root_toc_entry: false
31+
32+
## `nrdk upgrade-config`
33+
34+
:::nrdk._cli.cli_upgrade
35+
options:
36+
heading_level: 2
37+
show_symbol_type_heading: false
38+
show_signature: false
39+
separate_signature: false
40+
show_root_heading: false
41+
show_root_toc_entry: false
42+
43+
## `nrdk validate`
44+
45+
:::nrdk._cli.cli_validate
46+
options:
47+
heading_level: 2
48+
show_symbol_type_heading: false
49+
show_signature: false
50+
separate_signature: false
51+
show_root_heading: false
52+
show_root_toc_entry: false

docs/index.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,12 @@ Built around typed, highly modular interfaces, the NRDK is designed to reduce the
1212

1313
neural radar development kit core framework
1414

15+
- :octicons-terminal-16: [`> nrdk ...`](cli.md)
16+
17+
---
18+
19+
cli tools for working with NRDK output conventions
20+
1521
- :material-book-open-page-variant-outline: [`grt`](grt/index.md)
1622

1723
---

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
44

55
[project]
66
name = "nrdk"
7-
version = "0.1.3"
7+
version = "0.1.4"
88
authors = [
99
{ name="Tianshu Huang", email="tianshu2@andrew.cmu.edu" },
1010
]

src/nrdk/_cli/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
from .export import cli_export
88
from .inspect import cli_inspect
9+
from .upgrade import cli_upgrade
10+
from .validate import cli_validate
911

1012

1113
def make_annotation(name, func):
@@ -22,6 +24,8 @@ def cli_main() -> None:
2224
commands = {
2325
"inspect": cli_inspect,
2426
"export": cli_export,
27+
"upgrade-config": cli_upgrade,
28+
"validate": cli_validate,
2529
}
2630

2731
return tyro.cli(Union[ # type: ignore

src/nrdk/_cli/upgrade.py

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
"""Upgrade one implementation to another in hydra configs."""
2+
3+
import os
4+
import re
5+
6+
import numpy as np
7+
from rich import print
8+
from rich.columns import Columns
9+
from rich.panel import Panel
10+
11+
from nrdk.framework import Result
12+
13+
14+
def _format_context(
15+
context: str, line_num: int | np.integer,
16+
start: int | np.integer, end: int | np.integer
17+
) -> str:
18+
return '\n'.join([
19+
f"{'>>>' if n == line_num else ' '} {line}"
20+
for n, line in zip(range(start, end), context.split('\n'))
21+
])
22+
23+
24+
def _search(
25+
text: str, pattern: str | re.Pattern, context_size: int = 2
26+
) -> list[tuple[int, str]]:
27+
if isinstance(pattern, str):
28+
pattern = re.compile(pattern)
29+
30+
newlines = np.where(np.frombuffer(
31+
text.encode('utf-8'), dtype=np.uint8) == ord('\n'))[0]
32+
newlines = np.concatenate([[-1], newlines, [len(text)]])
33+
34+
matches = []
35+
search_start = 0
36+
while True:
37+
match = pattern.search(text, search_start)
38+
if not match:
39+
break
40+
41+
line_num = np.searchsorted(newlines, match.start(), side='right') - 1
42+
43+
start = max(0, line_num - context_size)
44+
end = min(len(newlines) - 1, line_num + 1 + context_size)
45+
46+
context = text[newlines[start] + 1:newlines[end]]
47+
matches.append(
48+
(line_num, _format_context(context, line_num, start, end)))
49+
50+
search_start = match.end()
51+
52+
return matches
53+
54+
55+
def cli_upgrade(
56+
target: str, /, to: str | None = None,
57+
dry_run: bool = False, path: str = ".", follow_symlinks: bool = False
58+
) -> None:
59+
"""Upgrade implementation references in hydra configs.
60+
61+
!!! info "Usage"
62+
63+
First test with a dry run:
64+
```sh
65+
nrdk upgrade-config <target> --path ./results --dry-run
66+
```
67+
If you're happy with what you see, you can then run the actual upgrade:
68+
```sh
69+
nrdk upgrade-config <target> <to> --path ./results
70+
```
71+
72+
!!! danger
73+
74+
This is a potentially destructive operation! Always run with
75+
`--dry-run` first, and make sure that `to` does not overlap with
76+
any other existing implementations in your configs.
77+
78+
You can also use the `upgrade-config` tool to check for this overlap
79+
first:
80+
```sh
81+
nrdk upgrade-config <to> --path ./results --dry-run
82+
# Shouldn't return any of the config files you are planning to upgrade
83+
```
84+
85+
For each valid [results directory][nrdk.framework.Result] in the specified
86+
`path`, search for all `_target_` fields in the hydra config, and replace
87+
any occurrences of `from` with `to`.
88+
89+
Args:
90+
target: full path name of the implementation to replace.
91+
to: full path name of the implementation to replace with.
92+
dry_run: if `True`, only log the changes that would be made, and do not
93+
actually modify any files.
94+
path: path to search for results directories.
95+
follow_symlinks: whether to follow symlinks when searching for results.
96+
"""
97+
pattern = re.compile(rf"_target_\s*:\s*{re.escape(target)}(?=\s|$)")
98+
results = Result.find(path, follow_symlinks=follow_symlinks)
99+
100+
if dry_run:
101+
all_matches = {}
102+
for r in results:
103+
config_path = os.path.join(r, ".hydra", "config.yaml")
104+
if os.path.exists(config_path):
105+
with open(config_path, "r") as f:
106+
config = f.read()
107+
108+
matches = _search(config, pattern, context_size=2)
109+
for line_num, context in matches:
110+
if context not in all_matches:
111+
all_matches[context] = []
112+
all_matches[context].append((config_path, line_num))
113+
114+
for k, v in all_matches.items():
115+
print(
116+
f"Found {len(v)} occurrence(s) of '{target}' "
117+
f"with this context:")
118+
print(Panel(k))
119+
print(Columns(
120+
f"{os.path.relpath(config_path, path)}:{line_num}"
121+
for config_path, line_num in v))
122+
print()
123+
124+
else:
125+
if to is None:
126+
raise ValueError("Must specify `to` when not doing a dry run.")
127+
128+
for r in results:
129+
config_path = os.path.join(r, ".hydra", "config.yaml")
130+
if os.path.exists(config_path):
131+
with open(config_path, "r") as f:
132+
config = f.read()
133+
134+
n = re.findall(pattern, config)
135+
if n:
136+
print(f"Upgrading {len(n)} occurrence(s): {config_path}")
137+
new_config = re.sub(pattern, f"_target_: {to}", config)
138+
with open(config_path, "w") as f:
139+
f.write(new_config)

src/nrdk/_cli/validate.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
"""Validate results directories."""
2+
3+
import os
4+
5+
from rich import print
6+
from rich.table import Table
7+
8+
from nrdk.framework import Result
9+
10+
11+
def cli_validate(
12+
path: str, /, follow_symlinks: bool = False, show_all: bool = False,
13+
) -> None:
14+
"""Validate results directories.
15+
16+
!!! info "Usage"
17+
18+
```sh
19+
nrdk validate <path> --follow_symlinks
20+
```
21+
22+
For each valid [results directory][nrdk.framework.Result] in the specified
23+
`path`, check that all expected files are present:
24+
25+
| File | Description |
26+
| ----------------------- | --------------------------------------------- |
27+
| `.hydra/config.yaml` | Hydra configuration used for the run. |
28+
| `checkpoints/last.ckpt` | Last model checkpoint saved during training. |
29+
| `eval/` | Directory containing evaluation outputs. |
30+
| `checkpoints.yaml` | Checkpoint index; absence indicates a crashed run. |
31+
| `events.out.tfevents.*` | Tensorboard log files. |
32+
33+
Args:
34+
path: path to search for results directories.
35+
follow_symlinks: whether to follow symlinks when searching for results.
36+
show_all: show all results instead of just results with missing files.
37+
"""
38+
results = Result.find(path, follow_symlinks=follow_symlinks, strict=False)
39+
40+
_check_files = [
41+
".hydra/config.yaml",
42+
"checkpoints/last.ckpt",
43+
"eval",
44+
"checkpoints.yaml",
45+
]
46+
_status = {
47+
True: u'[green]\u2713[/green]',
48+
False: u'[bold red]\u2718[/bold red]',
49+
}
50+
51+
missing = 0
52+
53+
table = Table()
54+
table.add_column("path", justify="right", style="cyan")
55+
table.add_column("config.yaml", justify="left")
56+
table.add_column("last.ckpt", justify="left")
57+
table.add_column("eval", justify="left")
58+
table.add_column("checkpoints.yaml", justify="left")
59+
table.add_column("tfevents", justify="left")
60+
61+
for r in results:
62+
row = [
63+
os.path.exists(os.path.join(r, file))
64+
for file in _check_files
65+
] + [any(
66+
fname.startswith("events.out.tfevents.")
67+
for fname in os.listdir(r)
68+
)]
69+
if not all(row):
70+
missing += 1
71+
if show_all or not all(row):
72+
table.add_row(os.path.relpath(r, path), *[_status[x] for x in row])
73+
74+
if missing > 0:
75+
print(
76+
f"Found {len(results)} results directories with {missing} "
77+
f"incomplete results.")
78+
else:
79+
print(f"All {len(results)} results directories are complete.")
80+
81+
print(table)

src/nrdk/framework/result.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,33 @@ def __init__(self, path: str, validate: bool = True) -> None:
5050
if validate:
5151
self.validate(path)
5252

53+
@staticmethod
54+
def find(
55+
path: str, follow_symlinks: bool = False, strict: bool = True
56+
) -> list[str]:
57+
"""Find all results directories under the given path.
58+
59+
Args:
60+
path: path to search under.
61+
follow_symlinks: if True, follow symlinks when searching.
62+
strict: if `True`, only return directories that pass validation;
63+
if `False`, return directories with any `.hydra` folder or
64+
`checkpoints.yaml` file.
65+
66+
Returns:
67+
List of paths to results directories (that contain a
68+
`checkpoints.yaml` and `.hydra` folder.)
69+
"""
70+
results = []
71+
for root, dirs, files in os.walk(path, followlinks=follow_symlinks):
72+
if strict:
73+
if ".hydra" in dirs and "checkpoints.yaml" in files:
74+
results.append(root)
75+
else:
76+
if ".hydra" in dirs or "checkpoints.yaml" in files:
77+
results.append(root)
78+
return results
79+
5380
@staticmethod
5481
def validate(path: str) -> None:
5582
"""Validate that the given path is a valid results directory.

uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)