"""Evaluation harness for HighOrbitAdvisor.

Loads scenarios from YAML, runs each one through the advisor, checks the
output format, and collects per-scenario records for later analysis.
"""

from __future__ import annotations

import json
import time
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import List, Dict, Any, Tuple

import yaml

from .advisor import HighOrbitAdvisor
from .config import configure_lm


@dataclass
class Scenario:
    """A single eval scenario: a user prompt plus an optional playbook file."""

    id: str
    prompt: str
    playbook_path: str | None = None


@dataclass
class EvalRecord:
    """One scenario run: advisor output plus format checks and critic feedback."""

    scenario_id: str
    prompt: str
    timestamp: float
    latency_ms: float
    advice: str
    actions: List[str]
    metric_to_watch: str
    risks: List[str]
    critic_score: int
    critic_feedback: str
    format_ok: bool
    actions_count: int
    risks_count: int


def load_scenarios(path: str | Path) -> List[Scenario]:
    """Load scenarios from a YAML file with a top-level ``scenarios`` list."""
    p = Path(path)
    # Treat an empty file as "no scenarios" instead of failing on None
    data = yaml.safe_load(p.read_text()) or {}
    scenarios: List[Scenario] = []
    for i, item in enumerate(data.get("scenarios", [])):
        scenarios.append(
            Scenario(
                id=item.get("id") or f"s{i+1}",
                prompt=item["prompt"],
                playbook_path=item.get("playbook"),
            )
        )
    return scenarios
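
# The loader above assumes a scenarios file shaped roughly like this; the key
# names come from load_scenarios itself, the values are placeholders:
#
#   scenarios:
#     - id: s1
#       prompt: "First scenario prompt"
#       playbook: path/to/playbook.md
#     - prompt: "Second scenario; id defaults to s2, playbook is optional"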


def _split_lines(value: Any) -> List[str]:
    """Normalize a list or newline-separated string into stripped, non-empty lines."""
    if isinstance(value, list):
        return [str(x).strip() for x in value if str(x).strip()]
    if isinstance(value, str):
        return [ln.strip() for ln in value.split("\n") if ln.strip()]
    return []


def _strip_bullet(line: str) -> str:
    """Strip leading list markers (digits, dots, dashes, bullets, asterisks)."""
    return line.lstrip("0123456789. -•*").strip()


def _format_eval(advice: str, actions_lines: List[str], risks_lines: List[str]) -> Tuple[bool, int, int]:
    # Clean bullet prefixes before counting
    a_clean = [_strip_bullet(ln) for ln in actions_lines if ln.strip()]
    r_clean = [_strip_bullet(ln) for ln in risks_lines if ln.strip()]
    actions_count = len(a_clean)
    risks_count = len(r_clean)
    # Format rules: 3-5 actions, exactly 3 risks, non-empty advice
    format_ok = (3 <= actions_count <= 5) and (risks_count == 3) and bool(advice and advice.strip())
    return format_ok, actions_count, risks_count


def run_evals(scenarios: List[Scenario]) -> List[EvalRecord]:
    """Run every scenario through the advisor and collect one EvalRecord each."""
    # Ensure LM is configured for online evals
    configure_lm()
    advisor = HighOrbitAdvisor()
    records: List[EvalRecord] = []

    for sc in scenarios:
        playbook = ""
        if sc.playbook_path and Path(sc.playbook_path).exists():
            playbook = Path(sc.playbook_path).read_text()

        history = [{"role": "user", "content": sc.prompt}]
        start = time.perf_counter()  # monotonic clock for latency measurement
        res = advisor(history=history, playbook=playbook)
        latency_ms = (time.perf_counter() - start) * 1000.0

        actions_lines = _split_lines(res.actions_48h)
        risks_lines = _split_lines(res.risks)
        format_ok, a_count, r_count = _format_eval(res.advice or "", actions_lines, risks_lines)

        records.append(
            EvalRecord(
                scenario_id=sc.id,
                prompt=sc.prompt,
                timestamp=time.time(),
                latency_ms=latency_ms,
                advice=res.advice or "",
                actions=[_strip_bullet(ln) for ln in actions_lines],
                metric_to_watch=res.metric_to_watch or "",
                risks=[_strip_bullet(ln) for ln in risks_lines],
                critic_score=int(getattr(res, "score", 0) or 0),
                critic_feedback=str(getattr(res, "critique", "") or ""),
                format_ok=format_ok,
                actions_count=a_count,
                risks_count=r_count,
            )
        )

    return records


def summarize_results(records: List[EvalRecord]) -> Dict[str, Any]:
    """Aggregate records into a count, format pass rate, and average score/latency."""
    n = len(records)
    if n == 0:
        return {"count": 0}
    fmt_ok = sum(1 for r in records if r.format_ok)
    avg_score = sum(r.critic_score for r in records) / n
    avg_latency = sum(r.latency_ms for r in records) / n
    return {
        "count": n,
        "format_ok_rate": fmt_ok / n,
        "avg_critic_score": avg_score,
        "avg_latency_ms": avg_latency,
    }


def save_eval_results(records: List[EvalRecord], out_path: str | Path) -> None:
    """Write records as JSON Lines, one EvalRecord per line."""
    p = Path(out_path)
    p.parent.mkdir(parents=True, exist_ok=True)
    with p.open("w") as f:
        for r in records:
            f.write(json.dumps(asdict(r)) + "\n")
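

# Minimal usage sketch tying the helpers together; the file paths below are
# placeholders, not paths from this repo.
if __name__ == "__main__":
    scenarios = load_scenarios("evals/scenarios.yaml")
    results = run_evals(scenarios)
    print(json.dumps(summarize_results(results), indent=2))
    save_eval_results(results, "evals/results.jsonl")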