
Commit 15709fc
Parent: a222c24

feat: persona-aware evals; export CSV/MD; provider-selectable model listing; critic LM + overlap-penalized rerank
File tree: 4 files changed (+244, -44 lines)

orbit_agent/advisor.py
Lines changed: 73 additions & 10 deletions

@@ -268,6 +268,7 @@ def forward(

         from .tools.retention import calculate_cohort_retention
         from .tools.funnel import analyze_funnel, FunnelStep
+        from .tools.finance import runway_months, expected_value

         # Extract inline JSON-ish snippet if present
         m = re.search(r"(\[.*\]|\{.*\})", history_str, re.DOTALL)
@@ -289,6 +290,32 @@ def forward(
                     res = analyze_funnel(steps)
                     rates = ", ".join(f"{r:.1f}%" for r in res.conversion_rates)
                     tool_results = f"Funnel steps={len(steps)}; Conversion: {rates}"
+                elif isinstance(data, dict) and {"cash", "burn"} <= set(
+                    data.keys()
+                ):
+                    rr = runway_months(
+                        float(data["cash"]),
+                        float(data["burn"]),
+                        float(data.get("growth", 0)),
+                    )
+                    tool_results = f"Runway≈{rr.months:.1f}m; Alive={'Yes' if rr.default_alive else 'No'}"
+                elif isinstance(data, dict) and {
+                    "p_upside",
+                    "ev_upside",
+                    "p_mid",
+                    "ev_mid",
+                    "p_down",
+                    "ev_down",
+                } <= set(data.keys()):
+                    ev = expected_value(
+                        float(data["p_upside"]),
+                        float(data["ev_upside"]),
+                        float(data["p_mid"]),
+                        float(data["ev_mid"]),
+                        float(data["p_down"]),
+                        float(data["ev_down"]),
+                    )
+                    tool_results = f"EV≈${ev:,.0f} across scenarios"
             except Exception:
                 tool_results = tool_results or ""

@@ -297,8 +324,39 @@ def forward(

         cfg = get_config()
         best_of_n = max(1, int(getattr(cfg, "best_of_n", 1)))
+        overlap_alpha = float(getattr(cfg, "overlap_alpha", 2.0) or 2.0)
         best_payload = None
-        best_score = -1
+        best_score = -1e9
+
+        # Optional separate critic LM
+        critic_lm = None
+        try:
+            import dspy as _dspy
+
+            if getattr(cfg, "critic_model", None):
+                critic_lm = _dspy.LM(model=cfg.critic_model)
+        except Exception:
+            critic_lm = None
+
+        def _ngram_set(text: str, n: int = 3) -> set[str]:
+            tokens = [t for t in text.lower().split() if t]
+            return set(
+                [
+                    " ".join(tokens[i : i + n])
+                    for i in range(0, max(0, len(tokens) - n + 1))
+                ]
+            )
+
+        def _overlap_ratio(a: str, b: str, n: int = 3) -> float:
+            if not a or not b:
+                return 0.0
+            A = _ngram_set(a, n)
+            B = _ngram_set(b, n)
+            if not A or not B:
+                return 0.0
+            inter = len(A & B)
+            union = len(A | B)
+            return inter / union

         for _ in range(best_of_n):
             logger.info("Generating advice with LLM")
@@ -310,24 +368,29 @@ def forward(
                 tool_results=tool_results or "No tools used in this session",
             )

-            # Critique with retry
+            # Critique with retry (optionally on separate critic LM)
             logger.info("Getting critique")
-            critique = self._call_llm_with_retry(
-                self.critic,
-                advice=self._clean_output(draft.advice),
-                context=context_with_history,
-            )
+            advice_clean = self._clean_output(draft.advice)
+            kwargs = dict(advice=advice_clean, context=context_with_history)
+            if critic_lm is not None:
+                critique = self._call_llm_with_retry(
+                    self.critic, **kwargs, lm=critic_lm
+                )
+            else:
+                critique = self._call_llm_with_retry(self.critic, **kwargs)

-            score = int(getattr(critique, "score", 0) or 0)
+            score_raw = int(getattr(critique, "score", 0) or 0)
+            overlap = _overlap_ratio(advice_clean, playbook)
+            score = score_raw - overlap_alpha * overlap
             if score > best_score:
                 best_score = score
                 best_payload = (
-                    self._clean_output(draft.advice),
+                    advice_clean,
                     self._clean_output(draft.actions_48h),
                     self._clean_output(draft.metric_to_watch),
                     self._clean_output(draft.risks),
                     critique.feedback,
-                    score,
+                    score_raw,
                 )

         advice, actions_48h, metric_to_watch, risks, feedback, score = best_payload
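
Note on the rerank above: each draft is scored as the critic score minus overlap_alpha times the trigram-Jaccard overlap between the draft and the playbook, so candidates that parrot the playbook get pushed down. A minimal, self-contained sketch of that rule on toy data (the texts and the 0-10 critic scale here are made up for illustration; the real critic scale is not shown in this diff):

# Toy rerank illustration: penalized = raw_score - alpha * Jaccard(trigram sets).
def ngram_set(text: str, n: int = 3) -> set[str]:
    tokens = [t for t in text.lower().split() if t]
    return {" ".join(tokens[i : i + n]) for i in range(max(0, len(tokens) - n + 1))}


def overlap_ratio(a: str, b: str, n: int = 3) -> float:
    A, B = ngram_set(a, n), ngram_set(b, n)
    return len(A & B) / len(A | B) if A and B else 0.0


playbook = "cut burn to extend runway and talk to ten customers every week"
candidates = [
    # (label, draft advice, raw critic score on a made-up 0-10 scale)
    ("parrots playbook", "cut burn to extend runway and talk to ten customers every week", 9),
    ("specific advice", "pause the rebrand for two weeks and close the three stalled enterprise deals", 8),
]
overlap_alpha = 2.0  # same default as the new ORBIT_OVERLAP_ALPHA

for label, advice, raw in candidates:
    ov = overlap_ratio(advice, playbook)
    print(f"{label}: raw={raw}, overlap={ov:.2f}, penalized={raw - overlap_alpha * ov:.2f}")
# parrots playbook -> penalized 7.00; specific advice -> penalized 8.00, so it wins the rerank.

In the diff itself the penalized value only drives best-of-N selection; the payload stores score_raw, so the score surfaced to the caller stays on the critic's own scale.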

orbit_agent/cli.py
Lines changed: 87 additions & 33 deletions

@@ -22,6 +22,9 @@
     grade_with_rubric,
     summarize_grades,
     save_grades,
+    load_eval_records,
+    export_summary_csv,
+    export_summary_md,
 )
 import os
 import subprocess
@@ -606,42 +609,93 @@ def eval_grade(
         raise typer.Exit(1)


-@models_app.command("list")
-def models_list():
-    """List models from the active provider (OpenAI only for now)."""
+@eval_app.command("summary")
+def eval_summary(
+    input_path: str = typer.Option(
+        ".orbit/evals/latest.jsonl", help="Path to JSONL results"
+    ),
+    csv_out: str = typer.Option(None, help="Write scenario summary to CSV here"),
+    md_out: str = typer.Option(None, help="Write scenario summary to Markdown here"),
+):
+    """Export per-scenario summaries to CSV/Markdown."""
     try:
-        import os
-        from openai import OpenAI
+        recs = load_eval_records(input_path)
+        if not recs:
+            console.print("[yellow]No records to summarize[/yellow]")
+            raise typer.Exit(0)
+        if csv_out:
+            export_summary_csv(recs, csv_out)
+            console.print(f"[green]CSV written:[/green] {csv_out}")
+        if md_out:
+            export_summary_md(recs, md_out)
+            console.print(f"[green]Markdown written:[/green] {md_out}")
+        if not csv_out and not md_out:
+            console.print(
+                "[yellow]No output paths specified (use --csv-out or --md-out)[/yellow]"
+            )
+    except Exception as e:
+        logger.error(f"Eval summary failed: {e}")
+        console.print(f"[bold red]Error:[/bold red] {e}")
+        raise typer.Exit(1)

-        key = os.getenv("OPENAI_API_KEY")
-        if not key:
-            console.print("[red]OPENAI_API_KEY not set[/red]")
-            raise typer.Exit(1)

-        client = OpenAI(api_key=key)
-        resp = client.models.list()
-        ids = [m.id for m in resp.data]
-        # Prefer relevant chat-capable frontier models
-        preferred = [
-            i
-            for i in ids
-            if any(
-                i.startswith(p)
-                for p in (
-                    "gpt-5",
-                    "gpt-4.1",
-                    "gpt-4o",
-                    "o3",
-                )
-            )
-        ]
-        console.print("[bold]Candidate Models[/]:")
-        for mid in sorted(preferred):
-            console.print(f"- {mid}")
-        others = [i for i in ids if i not in preferred]
-        console.print("\n[dim]Other models (truncated)[/dim]")
-        for mid in sorted(others)[:20]:
-            console.print(f"- {mid}")
+@models_app.command("list")
+def models_list(
+    provider: str = typer.Option("openai", help="Provider: openai|anthropic")
+):
+    """List models from a provider (token-free)."""
+    try:
+        provider = provider.lower()
+        if provider == "openai":
+            import os
+            from openai import OpenAI
+
+            key = os.getenv("OPENAI_API_KEY")
+            if not key:
+                console.print("[red]OPENAI_API_KEY not set[/red]")
+                raise typer.Exit(1)
+            client = OpenAI(api_key=key)
+            resp = client.models.list()
+            ids = [m.id for m in resp.data]
+            preferred = [
+                i
+                for i in ids
+                if any(i.startswith(p) for p in ("gpt-5", "gpt-4.1", "gpt-4o", "o3"))
+            ]
+            console.print("[bold]OpenAI Candidate Models[/]:")
+            for mid in sorted(preferred):
+                console.print(f"- {mid}")
+            others = [i for i in ids if i not in preferred]
+            console.print("\n[dim]Other models (truncated)[/dim]")
+            for mid in sorted(others)[:20]:
+                console.print(f"- {mid}")
+        elif provider == "anthropic":
+            try:
+                import os
+                import anthropic
+
+                key = os.getenv("ANTHROPIC_API_KEY")
+                if not key:
+                    console.print("[red]ANTHROPIC_API_KEY not set[/red]")
+                    raise typer.Exit(1)
+                client = anthropic.Anthropic(api_key=key)
+                # Anthropic SDK provides a fixed set; list known public IDs if API lacks listing.
+                known = [
+                    "claude-3-5-sonnet-20241022",
+                    "claude-3-5-haiku-20241022",
+                    "claude-3-opus-20240229",
+                    "claude-3-sonnet-20240229",
+                    "claude-3-haiku-20240307",
+                ]
+                console.print("[bold]Anthropic Models (known set)[/]:")
+                for mid in known:
+                    console.print(f"- {mid}")
+            except Exception as e:
+                console.print(f"[red]Anthropic listing not available:[/red] {e}")
+                raise typer.Exit(1)
+        else:
+            console.print("[red]Unsupported provider[/red]")
+            raise typer.Exit(1)
     except Exception as e:
         logger.error(f"Model listing failed: {e}")
         console.print(f"[bold red]Error:[/bold red] {e}")

orbit_agent/config.py
Lines changed: 4 additions & 0 deletions

@@ -77,6 +77,8 @@ class AppConfig:
     # Generation quality controls
     best_of_n: int = 1
     temperature: float = 0.7
+    critic_model: Optional[str] = None
+    overlap_alpha: float = 2.0

     def __post_init__(self):
         """Validate configuration after initialization"""
@@ -184,6 +186,8 @@ def load_config() -> AppConfig:
         cost_per_1k_completion=float(os.getenv("ORBIT_COST_PER_1K_COMPLETION", "0")),
         best_of_n=int(os.getenv("ORBIT_BEST_OF_N", "1")),
         temperature=float(os.getenv("ORBIT_TEMPERATURE", "0.7")),
+        critic_model=os.getenv("ORBIT_CRITIC_LM"),
+        overlap_alpha=float(os.getenv("ORBIT_OVERLAP_ALPHA", "2.0")),
     )

     return config
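
The two new fields are driven entirely by environment variables, so a separate critic model and a softer overlap penalty need no code changes. A minimal sketch, assuming the package imports as orbit_agent and the remaining settings resolve from their defaults; the model id and alpha value below are illustrative, not recommendations:

import os

# ORBIT_CRITIC_LM and ORBIT_OVERLAP_ALPHA are the env vars read in load_config() above.
os.environ["ORBIT_CRITIC_LM"] = "openai/gpt-4o-mini"  # illustrative critic model id
os.environ["ORBIT_OVERLAP_ALPHA"] = "1.0"             # halve the default penalty weight

from orbit_agent.config import load_config

cfg = load_config()
print(cfg.critic_model, cfg.overlap_alpha)  # -> openai/gpt-4o-mini 1.0

With ORBIT_CRITIC_LM unset, advisor.py falls back to critiquing on the main LM, and overlap_alpha keeps its 2.0 default.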

orbit_agent/evals.py
Lines changed: 80 additions & 1 deletion

@@ -116,8 +116,18 @@ def run_evals(scenarios: List[Scenario]) -> List[EvalRecord]:
         playbook = Path(sc.playbook_path).read_text()

         history = [{"role": "user", "content": sc.prompt}]
+        persona_ctx = []
+        if sc.persona:
+            persona_ctx.append(f"Persona: {sc.persona}")
+        if sc.stage:
+            persona_ctx.append(f"Stage: {sc.stage}")
+        if sc.rubric:
+            persona_ctx.append(
+                "Success Rubric:" + "\n" + "\n".join(f"- {r}" for r in sc.rubric)
+            )
+        ctx = "\n".join(persona_ctx) if persona_ctx else None
         start = time.time()
-        res = advisor(history=history, playbook=playbook)
+        res = advisor(history=history, playbook=playbook, context=ctx)
         latency_ms = (time.time() - start) * 1000.0

         actions_lines = _split_lines(res.actions_48h)
@@ -250,3 +260,72 @@ def save_grades(graded: List[Dict[str, Any]], out_path: str | Path) -> None:
     with p.open("w") as f:
         for g in graded:
             f.write(json.dumps(g) + "\n")
+
+
+def load_eval_records(path: str | Path) -> List[EvalRecord]:
+    p = Path(path)
+    recs: List[EvalRecord] = []
+    if not p.exists():
+        return recs
+    for line in p.read_text().splitlines():
+        if not line.strip():
+            continue
+        recs.append(EvalRecord(**json.loads(line)))
+    return recs
+
+
+def summarize_by_scenario(records: List[EvalRecord]) -> List[Dict[str, Any]]:
+    from collections import defaultdict
+
+    groups = defaultdict(list)
+    for r in records:
+        groups[r.scenario_id].append(r)
+    rows = []
+    for sid, items in groups.items():
+        n = len(items)
+        avg_score = sum(i.critic_score for i in items) / n
+        avg_overlap = sum((i.overlap_ratio or 0.0) for i in items) / n
+        avg_latency = sum(i.latency_ms for i in items) / n
+        rows.append(
+            {
+                "scenario_id": sid,
+                "count": n,
+                "avg_critic_score": avg_score,
+                "avg_overlap": avg_overlap,
+                "avg_latency_ms": avg_latency,
+            }
+        )
+    return rows
+
+
+def export_summary_csv(records: List[EvalRecord], out_path: str | Path) -> None:
+    import csv
+
+    rows = summarize_by_scenario(records)
+    with open(out_path, "w", newline="") as f:
+        writer = csv.DictWriter(
+            f,
+            fieldnames=[
+                "scenario_id",
+                "count",
+                "avg_critic_score",
+                "avg_overlap",
+                "avg_latency_ms",
+            ],
+        )
+        writer.writeheader()
+        for row in rows:
+            writer.writerow(row)
+
+
+def export_summary_md(records: List[EvalRecord], out_path: str | Path) -> None:
+    rows = summarize_by_scenario(records)
+    lines = [
+        "| Scenario | Count | Avg Score | Overlap | Latency (ms) |",
+        "|---|---:|---:|---:|---:|",
+    ]
+    for r in rows:
+        lines.append(
+            f"| {r['scenario_id']} | {r['count']} | {r['avg_critic_score']:.2f} | {r['avg_overlap']:.2f} | {r['avg_latency_ms']:.0f} |"
+        )
+    Path(out_path).write_text("\n".join(lines) + "\n")
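
The new helpers also compose outside the CLI, which is handy in notebooks or CI steps. A short sketch, assuming results were written to the default path used by the eval summary command; the output filenames are illustrative:

from orbit_agent.evals import (
    export_summary_csv,
    export_summary_md,
    load_eval_records,
    summarize_by_scenario,
)

records = load_eval_records(".orbit/evals/latest.jsonl")  # default path from cli.py
for row in summarize_by_scenario(records):
    print(row["scenario_id"], f"avg score {row['avg_critic_score']:.2f}", f"{row['avg_latency_ms']:.0f} ms")

export_summary_csv(records, "eval_summary.csv")
export_summary_md(records, "eval_summary.md")

A missing results file yields an empty list, and the CLI path above prints a warning and exits cleanly instead of failing.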
