Skip to content

Commit 458ed51

Browse files
authored
Merge pull request #112 from WecoAI/feature/multifile
Update cli to handle multi-file optimization
2 parents 2817886 + 2124adb commit 458ed51

File tree

7 files changed

+453
-104
lines changed

7 files changed

+453
-104
lines changed

tests/test_artifacts.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
"""Tests for run artifact persistence and path sanitization."""
2+
3+
import json
4+
5+
import pytest
6+
7+
from weco.artifacts import RunArtifacts, _sanitize_artifact_path
8+
9+
10+
@pytest.fixture
11+
def artifacts(tmp_path):
12+
return RunArtifacts(log_dir=str(tmp_path), run_id="test-run")
13+
14+
15+
def _read_manifest(path):
16+
return json.loads(path.read_text())
17+
18+
19+
@pytest.mark.parametrize(
20+
("raw_path", "expected_parts"),
21+
[
22+
("model.py", ("model.py",)),
23+
("src/model.py", ("src", "model.py")),
24+
("./src/model.py", ("src", "model.py")),
25+
("/absolute/path.py", ("absolute", "path.py")),
26+
("src\\utils\\helper.py", ("src", "utils", "helper.py")),
27+
("../../etc/passwd", ("etc", "passwd")),
28+
("", ("unnamed_file",)),
29+
("../../..", ("unnamed_file",)),
30+
],
31+
)
32+
def test_sanitize_artifact_path(raw_path, expected_parts):
33+
assert _sanitize_artifact_path(raw_path).parts == expected_parts
34+
35+
36+
def test_save_step_code_writes_files_and_manifest(artifacts):
37+
bundle = artifacts.save_step_code(
38+
step=3, file_map={"src/model.py": "class Model: pass", "src/utils.py": "def helper(): pass"}
39+
)
40+
41+
assert bundle == artifacts.root / "steps" / "3"
42+
assert (bundle / "files" / "src" / "model.py").read_text() == "class Model: pass"
43+
assert (bundle / "files" / "src" / "utils.py").read_text() == "def helper(): pass"
44+
45+
manifest = _read_manifest(bundle / "manifest.json")
46+
assert manifest["type"] == "step_code_snapshot"
47+
assert manifest["step"] == 3
48+
assert manifest["file_count"] == 2
49+
assert [file_entry["path"] for file_entry in manifest["files"]] == ["src/model.py", "src/utils.py"]
50+
assert [file_entry["artifact_path"] for file_entry in manifest["files"]] == ["src/model.py", "src/utils.py"]
51+
52+
53+
def test_save_step_code_keeps_steps_independent(artifacts):
54+
artifacts.save_step_code(step=0, file_map={"f.py": "v1"})
55+
artifacts.save_step_code(step=1, file_map={"f.py": "v2"})
56+
57+
assert (artifacts.root / "steps" / "0" / "files" / "f.py").read_text() == "v1"
58+
assert (artifacts.root / "steps" / "1" / "files" / "f.py").read_text() == "v2"
59+
60+
61+
def test_save_best_code_writes_manifest_without_step(artifacts):
62+
bundle = artifacts.save_best_code({"model.py": "optimized = True"})
63+
64+
assert bundle == artifacts.root / "best"
65+
assert (bundle / "files" / "model.py").read_text() == "optimized = True"
66+
67+
manifest = _read_manifest(bundle / "manifest.json")
68+
assert manifest["type"] == "best_code_snapshot"
69+
assert manifest["file_count"] == 1
70+
assert "step" not in manifest
71+
72+
73+
def test_save_execution_output_writes_step_file_and_jsonl_index(artifacts):
74+
artifacts.save_execution_output(step=0, output="first")
75+
artifacts.save_execution_output(step=1, output="second")
76+
77+
assert (artifacts.root / "outputs" / "step_0.out.txt").read_text() == "first"
78+
assert (artifacts.root / "outputs" / "step_1.out.txt").read_text() == "second"
79+
80+
lines = (artifacts.root / "exec_output.jsonl").read_text().strip().split("\n")
81+
assert len(lines) == 2
82+
first_entry = json.loads(lines[0])
83+
second_entry = json.loads(lines[1])
84+
85+
assert first_entry["step"] == 0
86+
assert first_entry["output_file"] == "outputs/step_0.out.txt"
87+
assert first_entry["output_length"] == len("first")
88+
assert second_entry["step"] == 1
89+
assert second_entry["output_file"] == "outputs/step_1.out.txt"
90+
assert second_entry["output_length"] == len("second")
91+
92+
93+
def test_root_directory_creation_is_idempotent(tmp_path):
94+
first = RunArtifacts(log_dir=str(tmp_path), run_id="abc-123")
95+
second = RunArtifacts(log_dir=str(tmp_path), run_id="abc-123")
96+
97+
assert first.root == second.root == (tmp_path / "abc-123")
98+
assert first.root.exists()
99+
100+
101+
def test_step_snapshot_sanitizes_path_traversal(artifacts, tmp_path):
102+
artifacts.save_step_code(step=0, file_map={"../../etc/evil.py": "malicious"})
103+
104+
assert not (tmp_path / "etc" / "evil.py").exists()
105+
assert (artifacts.root / "steps" / "0" / "files" / "etc" / "evil.py").exists()
106+
107+
108+
def test_best_snapshot_sanitizes_path_traversal(artifacts, tmp_path):
109+
artifacts.save_best_code({"../../../tmp/evil.py": "malicious"})
110+
111+
assert not (tmp_path.parent / "tmp" / "evil.py").exists()
112+
assert (artifacts.root / "best" / "files" / "tmp" / "evil.py").exists()

weco/api.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,8 @@ def _recover_suggest_after_transport_error(
113113

114114
def start_optimization_run(
115115
console: Console,
116-
source_code: str,
117-
source_path: str,
116+
source_code: str | dict[str, str],
117+
source_path: str | None,
118118
evaluation_command: str,
119119
metric_name: str,
120120
maximize: bool,

weco/artifacts.py

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
"""On-disk artifact management for optimization runs.
2+
3+
Centralizes the directory layout and write logic for all artifacts
4+
produced during an optimization run under .runs/<run_id>/.
5+
"""
6+
7+
import json
8+
import pathlib
9+
from datetime import datetime
10+
11+
12+
def _sanitize_artifact_path(path_value: str) -> pathlib.Path:
13+
"""Convert a source path into a safe relative artifact path.
14+
15+
Strips traversal components (..), absolute prefixes, and Windows
16+
drive letters so that artifacts are always written under the
17+
intended directory.
18+
"""
19+
normalized = path_value.replace("\\", "/")
20+
parts = pathlib.PurePosixPath(normalized).parts
21+
safe_parts: list[str] = []
22+
for part in parts:
23+
if part in ("", ".", "/"):
24+
continue
25+
if part == "..":
26+
continue
27+
if not safe_parts and ":" in part:
28+
part = part.replace(":", "_")
29+
safe_parts.append(part)
30+
31+
if not safe_parts:
32+
return pathlib.Path("unnamed_file")
33+
return pathlib.Path(*safe_parts)
34+
35+
36+
class RunArtifacts:
37+
"""Manages the on-disk artifact layout for a single optimization run.
38+
39+
Layout::
40+
41+
<root>/
42+
steps/<step>/
43+
files/<relative_path> # actual code files
44+
manifest.json # machine-readable index
45+
best/
46+
files/<relative_path>
47+
manifest.json
48+
outputs/
49+
step_<n>.out.txt # execution stdout/stderr
50+
exec_output.jsonl # centralized output index
51+
"""
52+
53+
def __init__(self, log_dir: str, run_id: str) -> None:
54+
self.root = pathlib.Path(log_dir) / run_id
55+
self.root.mkdir(parents=True, exist_ok=True)
56+
57+
# ------------------------------------------------------------------
58+
# Code snapshots
59+
# ------------------------------------------------------------------
60+
61+
def save_step_code(self, step: int, file_map: dict[str, str]) -> pathlib.Path:
62+
"""Write code snapshot + manifest for a given step.
63+
64+
Returns the bundle directory path.
65+
"""
66+
return self._write_code_bundle(file_map, label=("steps", str(step)))
67+
68+
def save_best_code(self, file_map: dict[str, str]) -> pathlib.Path:
69+
"""Write code snapshot + manifest for the best result.
70+
71+
Returns the bundle directory path.
72+
"""
73+
return self._write_code_bundle(file_map, label=("best",))
74+
75+
# ------------------------------------------------------------------
76+
# Execution output
77+
# ------------------------------------------------------------------
78+
79+
def save_execution_output(self, step: int, output: str) -> None:
80+
"""Save execution output as a per-step file and append to the JSONL index."""
81+
timestamp = datetime.now().isoformat()
82+
83+
outputs_dir = self.root / "outputs"
84+
# Keep raw execution output per step for easy local inspection.
85+
outputs_dir.mkdir(parents=True, exist_ok=True)
86+
87+
step_file = outputs_dir / f"step_{step}.out.txt"
88+
# Store full stdout/stderr for this exact step.
89+
step_file.write_text(output, encoding="utf-8")
90+
91+
jsonl_file = self.root / "exec_output.jsonl"
92+
entry = {
93+
"step": step,
94+
"timestamp": timestamp,
95+
"output_file": step_file.relative_to(self.root).as_posix(),
96+
"output_length": len(output),
97+
}
98+
# Append compact metadata so tooling can stream/index outputs.
99+
with open(jsonl_file, "a", encoding="utf-8") as f:
100+
f.write(json.dumps(entry) + "\n")
101+
102+
# ------------------------------------------------------------------
103+
# Internal helpers
104+
# ------------------------------------------------------------------
105+
106+
def _write_code_bundle(self, file_map: dict[str, str], label: tuple[str, ...]) -> pathlib.Path:
107+
bundle_dir = self.root.joinpath(*label)
108+
files_dir = bundle_dir / "files"
109+
files_dir.mkdir(parents=True, exist_ok=True)
110+
111+
files_manifest: list[dict[str, str | int]] = []
112+
for source_path, content in sorted(file_map.items()):
113+
artifact_rel = _sanitize_artifact_path(source_path)
114+
artifact_path = files_dir / artifact_rel
115+
artifact_path.parent.mkdir(parents=True, exist_ok=True)
116+
artifact_path.write_text(content, encoding="utf-8")
117+
files_manifest.append(
118+
{"path": source_path, "artifact_path": artifact_rel.as_posix(), "bytes": len(content.encode("utf-8"))}
119+
)
120+
121+
is_step = label[0] == "steps"
122+
manifest: dict = {
123+
"type": "step_code_snapshot" if is_step else "best_code_snapshot",
124+
"created_at": datetime.now().isoformat(),
125+
"file_count": len(files_manifest),
126+
"files": files_manifest,
127+
}
128+
if is_step:
129+
manifest["step"] = int(label[1])
130+
131+
manifest_path = bundle_dir / "manifest.json"
132+
manifest_path.write_text(json.dumps(manifest, indent=2) + "\n", encoding="utf-8")
133+
return bundle_dir

weco/cli.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
RunStartAttemptedEvent,
1616
)
1717
from .utils import check_for_cli_updates, get_default_model, UnrecognizedAPIKeysError, DefaultModelNotFoundError
18-
from .validation import validate_source_file, validate_log_directory, ValidationError, print_validation_error
18+
from .validation import validate_sources, validate_log_directory, ValidationError, print_validation_error
1919

2020

2121
install(show_locals=True)
@@ -55,12 +55,15 @@ def parse_api_keys(api_key_args: list[str] | None) -> dict[str, str]:
5555
# Function to define and return the run_parser (or configure it on a passed subparser object)
5656
# This helps keep main() cleaner and centralizes run command arg definitions.
5757
def configure_run_parser(run_parser: argparse.ArgumentParser) -> None:
58-
run_parser.add_argument(
59-
"-s",
60-
"--source",
58+
source_group = run_parser.add_mutually_exclusive_group(required=True)
59+
source_group.add_argument(
60+
"-s", "--source", type=str, help="Path to a single source code file to be optimized (e.g., `optimize.py`)"
61+
)
62+
source_group.add_argument(
63+
"--sources",
64+
nargs="+",
6165
type=str,
62-
required=True,
63-
help="Path to the source code file that will be optimized (e.g., `optimize.py`)",
66+
help="Paths to multiple source code files to be optimized together (e.g., `model.py utils.py config.py`)",
6467
)
6568
run_parser.add_argument(
6669
"-c",
@@ -111,7 +114,7 @@ def configure_run_parser(run_parser: argparse.ArgumentParser) -> None:
111114
run_parser.add_argument(
112115
"--save-logs",
113116
action="store_true",
114-
help="Save execution output to .runs/<run-id>/outputs/step_<n>.out.txt with JSONL index",
117+
help="Save execution output to .runs/<run-id>/outputs/step_<n>.out.txt with JSONL index. Code snapshots are written to .runs/<run-id>/steps/<step>/files and .runs/<run-id>/best/files.",
115118
)
116119
run_parser.add_argument(
117120
"--apply-change",
@@ -263,9 +266,12 @@ def execute_run_command(args: argparse.Namespace) -> None:
263266

264267
ctx = get_event_context()
265268

269+
# Normalize source input so --source follows the same internal path as --sources
270+
source_arg = args.sources if args.sources is not None else [args.source]
271+
266272
# Early validation — fail fast with helpful errors
267273
try:
268-
validate_source_file(args.source)
274+
validate_sources(source_arg)
269275
validate_log_directory(args.log_dir)
270276
except ValidationError as e:
271277
print_validation_error(e, console)
@@ -301,7 +307,7 @@ def execute_run_command(args: argparse.Namespace) -> None:
301307
)
302308

303309
success = optimize(
304-
source=args.source,
310+
source=source_arg,
305311
eval_command=args.eval_command,
306312
metric=args.metric,
307313
goal=args.goal,

0 commit comments

Comments
 (0)