Skip to content

Commit f1f67db

Browse files
committed
Update cli to handle multi-file optimization
1 parent 2817886 commit f1f67db

File tree

7 files changed

+470
-100
lines changed

7 files changed

+470
-100
lines changed

tests/test_artifacts.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
"""Tests for run artifact persistence and path sanitization."""
2+
3+
import json
4+
5+
import pytest
6+
7+
from weco.artifacts import RunArtifacts, _sanitize_artifact_path
8+
9+
10+
@pytest.fixture
11+
def artifacts(tmp_path):
12+
return RunArtifacts(log_dir=str(tmp_path), run_id="test-run")
13+
14+
15+
def _read_manifest(path):
16+
return json.loads(path.read_text())
17+
18+
19+
@pytest.mark.parametrize(
20+
("raw_path", "expected_parts"),
21+
[
22+
("model.py", ("model.py",)),
23+
("src/model.py", ("src", "model.py")),
24+
("./src/model.py", ("src", "model.py")),
25+
("/absolute/path.py", ("absolute", "path.py")),
26+
("src\\utils\\helper.py", ("src", "utils", "helper.py")),
27+
("../../etc/passwd", ("etc", "passwd")),
28+
("", ("unnamed_file",)),
29+
("../../..", ("unnamed_file",)),
30+
],
31+
)
32+
def test_sanitize_artifact_path(raw_path, expected_parts):
33+
assert _sanitize_artifact_path(raw_path).parts == expected_parts
34+
35+
36+
def test_save_step_code_writes_files_and_manifest(artifacts):
37+
bundle = artifacts.save_step_code(
38+
step=3,
39+
file_map={
40+
"src/model.py": "class Model: pass",
41+
"src/utils.py": "def helper(): pass",
42+
},
43+
)
44+
45+
assert bundle == artifacts.root / "steps" / "3"
46+
assert (bundle / "files" / "src" / "model.py").read_text() == "class Model: pass"
47+
assert (bundle / "files" / "src" / "utils.py").read_text() == "def helper(): pass"
48+
49+
manifest = _read_manifest(bundle / "manifest.json")
50+
assert manifest["type"] == "step_code_snapshot"
51+
assert manifest["step"] == 3
52+
assert manifest["file_count"] == 2
53+
assert [file_entry["path"] for file_entry in manifest["files"]] == ["src/model.py", "src/utils.py"]
54+
assert [file_entry["artifact_path"] for file_entry in manifest["files"]] == ["src/model.py", "src/utils.py"]
55+
56+
57+
def test_save_step_code_keeps_steps_independent(artifacts):
58+
artifacts.save_step_code(step=0, file_map={"f.py": "v1"})
59+
artifacts.save_step_code(step=1, file_map={"f.py": "v2"})
60+
61+
assert (artifacts.root / "steps" / "0" / "files" / "f.py").read_text() == "v1"
62+
assert (artifacts.root / "steps" / "1" / "files" / "f.py").read_text() == "v2"
63+
64+
65+
def test_save_best_code_writes_manifest_without_step(artifacts):
66+
bundle = artifacts.save_best_code({"model.py": "optimized = True"})
67+
68+
assert bundle == artifacts.root / "best"
69+
assert (bundle / "files" / "model.py").read_text() == "optimized = True"
70+
71+
manifest = _read_manifest(bundle / "manifest.json")
72+
assert manifest["type"] == "best_code_snapshot"
73+
assert manifest["file_count"] == 1
74+
assert "step" not in manifest
75+
76+
77+
def test_save_execution_output_writes_step_file_and_jsonl_index(artifacts):
78+
artifacts.save_execution_output(step=0, output="first")
79+
artifacts.save_execution_output(step=1, output="second")
80+
81+
assert (artifacts.root / "outputs" / "step_0.out.txt").read_text() == "first"
82+
assert (artifacts.root / "outputs" / "step_1.out.txt").read_text() == "second"
83+
84+
lines = (artifacts.root / "exec_output.jsonl").read_text().strip().split("\n")
85+
assert len(lines) == 2
86+
first_entry = json.loads(lines[0])
87+
second_entry = json.loads(lines[1])
88+
89+
assert first_entry["step"] == 0
90+
assert first_entry["output_file"] == "outputs/step_0.out.txt"
91+
assert first_entry["output_length"] == len("first")
92+
assert second_entry["step"] == 1
93+
assert second_entry["output_file"] == "outputs/step_1.out.txt"
94+
assert second_entry["output_length"] == len("second")
95+
96+
97+
def test_root_directory_creation_is_idempotent(tmp_path):
98+
first = RunArtifacts(log_dir=str(tmp_path), run_id="abc-123")
99+
second = RunArtifacts(log_dir=str(tmp_path), run_id="abc-123")
100+
101+
assert first.root == second.root == (tmp_path / "abc-123")
102+
assert first.root.exists()
103+
104+
105+
def test_step_snapshot_sanitizes_path_traversal(artifacts, tmp_path):
106+
artifacts.save_step_code(step=0, file_map={"../../etc/evil.py": "malicious"})
107+
108+
assert not (tmp_path / "etc" / "evil.py").exists()
109+
assert (artifacts.root / "steps" / "0" / "files" / "etc" / "evil.py").exists()
110+
111+
112+
def test_best_snapshot_sanitizes_path_traversal(artifacts, tmp_path):
113+
artifacts.save_best_code({"../../../tmp/evil.py": "malicious"})
114+
115+
assert not (tmp_path.parent / "tmp" / "evil.py").exists()
116+
assert (artifacts.root / "best" / "files" / "tmp" / "evil.py").exists()

weco/api.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,8 @@ def _recover_suggest_after_transport_error(
113113

114114
def start_optimization_run(
115115
console: Console,
116-
source_code: str,
117-
source_path: str,
116+
source_code: str | dict[str, str],
117+
source_path: str | None,
118118
evaluation_command: str,
119119
metric_name: str,
120120
maximize: bool,

weco/artifacts.py

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
"""On-disk artifact management for optimization runs.
2+
3+
Centralizes the directory layout and write logic for all artifacts
4+
produced during an optimization run under .runs/<run_id>/.
5+
"""
6+
7+
import json
8+
import pathlib
9+
from datetime import datetime
10+
11+
12+
def _sanitize_artifact_path(path_value: str) -> pathlib.Path:
13+
"""Convert a source path into a safe relative artifact path.
14+
15+
Strips traversal components (..), absolute prefixes, and Windows
16+
drive letters so that artifacts are always written under the
17+
intended directory.
18+
"""
19+
normalized = path_value.replace("\\", "/")
20+
parts = pathlib.PurePosixPath(normalized).parts
21+
safe_parts: list[str] = []
22+
for part in parts:
23+
if part in ("", ".", "/"):
24+
continue
25+
if part == "..":
26+
continue
27+
if not safe_parts and ":" in part:
28+
part = part.replace(":", "_")
29+
safe_parts.append(part)
30+
31+
if not safe_parts:
32+
return pathlib.Path("unnamed_file")
33+
return pathlib.Path(*safe_parts)
34+
35+
36+
class RunArtifacts:
37+
"""Manages the on-disk artifact layout for a single optimization run.
38+
39+
Layout::
40+
41+
<root>/
42+
steps/<step>/
43+
files/<relative_path> # actual code files
44+
manifest.json # machine-readable index
45+
best/
46+
files/<relative_path>
47+
manifest.json
48+
outputs/
49+
step_<n>.out.txt # execution stdout/stderr
50+
exec_output.jsonl # centralized output index
51+
"""
52+
53+
def __init__(self, log_dir: str, run_id: str) -> None:
54+
self.root = pathlib.Path(log_dir) / run_id
55+
self.root.mkdir(parents=True, exist_ok=True)
56+
57+
# ------------------------------------------------------------------
58+
# Code snapshots
59+
# ------------------------------------------------------------------
60+
61+
def save_step_code(self, step: int, file_map: dict[str, str]) -> pathlib.Path:
62+
"""Write code snapshot + manifest for a given step.
63+
64+
Returns the bundle directory path.
65+
"""
66+
return self._write_code_bundle(file_map, label=("steps", str(step)))
67+
68+
def save_best_code(self, file_map: dict[str, str]) -> pathlib.Path:
69+
"""Write code snapshot + manifest for the best result.
70+
71+
Returns the bundle directory path.
72+
"""
73+
return self._write_code_bundle(file_map, label=("best",))
74+
75+
# ------------------------------------------------------------------
76+
# Execution output
77+
# ------------------------------------------------------------------
78+
79+
def save_execution_output(self, step: int, output: str) -> None:
80+
"""Save execution output as a per-step file and append to the JSONL index."""
81+
timestamp = datetime.now().isoformat()
82+
83+
outputs_dir = self.root / "outputs"
84+
# Keep raw execution output per step for easy local inspection.
85+
outputs_dir.mkdir(parents=True, exist_ok=True)
86+
87+
step_file = outputs_dir / f"step_{step}.out.txt"
88+
# Store full stdout/stderr for this exact step.
89+
step_file.write_text(output, encoding="utf-8")
90+
91+
jsonl_file = self.root / "exec_output.jsonl"
92+
entry = {
93+
"step": step,
94+
"timestamp": timestamp,
95+
"output_file": step_file.relative_to(self.root).as_posix(),
96+
"output_length": len(output),
97+
}
98+
# Append compact metadata so tooling can stream/index outputs.
99+
with open(jsonl_file, "a", encoding="utf-8") as f:
100+
f.write(json.dumps(entry) + "\n")
101+
102+
# ------------------------------------------------------------------
103+
# Internal helpers
104+
# ------------------------------------------------------------------
105+
106+
def _write_code_bundle(
107+
self,
108+
file_map: dict[str, str],
109+
label: tuple[str, ...],
110+
) -> pathlib.Path:
111+
bundle_dir = self.root.joinpath(*label)
112+
files_dir = bundle_dir / "files"
113+
files_dir.mkdir(parents=True, exist_ok=True)
114+
115+
files_manifest: list[dict[str, str | int]] = []
116+
for source_path, content in sorted(file_map.items()):
117+
artifact_rel = _sanitize_artifact_path(source_path)
118+
artifact_path = files_dir / artifact_rel
119+
artifact_path.parent.mkdir(parents=True, exist_ok=True)
120+
artifact_path.write_text(content, encoding="utf-8")
121+
files_manifest.append(
122+
{
123+
"path": source_path,
124+
"artifact_path": artifact_rel.as_posix(),
125+
"bytes": len(content.encode("utf-8")),
126+
}
127+
)
128+
129+
is_step = label[0] == "steps"
130+
manifest: dict = {
131+
"type": "step_code_snapshot" if is_step else "best_code_snapshot",
132+
"created_at": datetime.now().isoformat(),
133+
"file_count": len(files_manifest),
134+
"files": files_manifest,
135+
}
136+
if is_step:
137+
manifest["step"] = int(label[1])
138+
139+
manifest_path = bundle_dir / "manifest.json"
140+
manifest_path.write_text(json.dumps(manifest, indent=2) + "\n", encoding="utf-8")
141+
return bundle_dir

weco/cli.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
RunStartAttemptedEvent,
1616
)
1717
from .utils import check_for_cli_updates, get_default_model, UnrecognizedAPIKeysError, DefaultModelNotFoundError
18-
from .validation import validate_source_file, validate_log_directory, ValidationError, print_validation_error
18+
from .validation import validate_sources, validate_log_directory, ValidationError, print_validation_error
1919

2020

2121
install(show_locals=True)
@@ -55,12 +55,18 @@ def parse_api_keys(api_key_args: list[str] | None) -> dict[str, str]:
5555
# Function to define and return the run_parser (or configure it on a passed subparser object)
5656
# This helps keep main() cleaner and centralizes run command arg definitions.
5757
def configure_run_parser(run_parser: argparse.ArgumentParser) -> None:
58-
run_parser.add_argument(
58+
source_group = run_parser.add_mutually_exclusive_group(required=True)
59+
source_group.add_argument(
5960
"-s",
6061
"--source",
6162
type=str,
62-
required=True,
63-
help="Path to the source code file that will be optimized (e.g., `optimize.py`)",
63+
help="Path to a single source code file to be optimized (e.g., `optimize.py`)",
64+
)
65+
source_group.add_argument(
66+
"--sources",
67+
nargs="+",
68+
type=str,
69+
help="Paths to multiple source code files to be optimized together (e.g., `model.py utils.py config.py`)",
6470
)
6571
run_parser.add_argument(
6672
"-c",
@@ -111,7 +117,7 @@ def configure_run_parser(run_parser: argparse.ArgumentParser) -> None:
111117
run_parser.add_argument(
112118
"--save-logs",
113119
action="store_true",
114-
help="Save execution output to .runs/<run-id>/outputs/step_<n>.out.txt with JSONL index",
120+
help="Save execution output to .runs/<run-id>/outputs/step_<n>.out.txt with JSONL index. Code snapshots are written to .runs/<run-id>/steps/<step>/files and .runs/<run-id>/best/files.",
115121
)
116122
run_parser.add_argument(
117123
"--apply-change",
@@ -263,9 +269,12 @@ def execute_run_command(args: argparse.Namespace) -> None:
263269

264270
ctx = get_event_context()
265271

272+
# Normalize source input so --source follows the same internal path as --sources
273+
source_arg = args.sources if args.sources is not None else [args.source]
274+
266275
# Early validation — fail fast with helpful errors
267276
try:
268-
validate_source_file(args.source)
277+
validate_sources(source_arg)
269278
validate_log_directory(args.log_dir)
270279
except ValidationError as e:
271280
print_validation_error(e, console)
@@ -301,7 +310,7 @@ def execute_run_command(args: argparse.Namespace) -> None:
301310
)
302311

303312
success = optimize(
304-
source=args.source,
313+
source=source_arg,
305314
eval_command=args.eval_command,
306315
metric=args.metric,
307316
goal=args.goal,

0 commit comments

Comments
 (0)