Skip to content

Commit 5237c65

Browse files
committed
feat(usage): include usage and cost estimates in CLI output and saved traces; API /ask returns usage and cost_usd; add estimate_prediction_cost helper
1 parent 6dddd3d commit 5237c65

File tree

5 files changed

+46
-8
lines changed

5 files changed

+46
-8
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ micro-agent replay --path traces/<id>.jsonl --index -1
6868
## HTTP API
6969
- Start: `uvicorn micro_agent.server:app --reload --port 8000`
7070
- Endpoint: `POST /ask`
71-
- Request JSON: `{ "question": "...", "max_steps": 6 }`
72-
- Response JSON: `{ "answer": str, "trace_id": str, "trace_path": str, "steps": [...] }`
71+
- Request JSON: `{ "question": "...", "max_steps": 6, "use_tool_calls": bool? }`
72+
- Response JSON: `{ "answer": str, "trace_id": str, "trace_path": str, "steps": [...], "usage": {...}, "cost_usd": number }`
7373

7474
Example:
7575
```bash

micro_agent/cli.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from .config import configure_lm
66
from .agent import MicroAgent
77
from .runtime import dump_trace, new_trace_id
8+
from .costs import estimate_prediction_cost
89

910
console = Console()
1011

@@ -63,8 +64,12 @@ def main():
6364

6465
pred = agent(q)
6566
trace_id = new_trace_id()
66-
path = dump_trace(trace_id, q, pred.trace, pred.answer)
67+
usage = getattr(pred, "usage", {}) or {}
68+
est = estimate_prediction_cost(q, pred.trace, pred.answer, usage)
69+
path = dump_trace(trace_id, q, pred.trace, pred.answer, usage=usage, cost_usd=est.get("cost_usd"))
6770

6871
console.print(Panel.fit(pred.answer, title="ANSWER"))
6972
console.print()
73+
console.print(Panel.fit(json.dumps({"usage": usage, "estimates": est}, indent=2, ensure_ascii=False), title="USAGE / ESTIMATES"))
74+
console.print()
7075
console.print(Panel.fit(json.dumps(pred.trace, indent=2, ensure_ascii=False), title=f"TRACE (saved: {path})"))

micro_agent/costs.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22
import os
3-
from typing import Tuple
3+
from typing import Tuple, Any, Dict
44

55
def _try_tiktoken(model: str):
66
try:
@@ -63,3 +63,25 @@ def get_prices_per_1k(model: str, provider: str) -> Tuple[float, float]:
6363
def estimate_cost_usd(input_tokens: int, output_tokens: int, model: str, provider: str) -> float:
6464
in_price_1k, out_price_1k = get_prices_per_1k(model, provider)
6565
return (input_tokens / 1000.0) * in_price_1k + (output_tokens / 1000.0) * out_price_1k
66+
67+
def estimate_prediction_cost(question: str, trace: Any, answer: str, usage: Dict[str, Any]) -> Dict[str, Any]:
68+
"""Estimate token usage and USD cost for a single prediction.
69+
70+
Heuristic: input tokens ~= lm_calls * tokens(question) + tokens(str(trace))
71+
output tokens ~= tokens(answer)
72+
"""
73+
provider = (usage or {}).get("provider") or "openai"
74+
model = (usage or {}).get("model") or "gpt-4o-mini"
75+
lm_calls = int((usage or {}).get("lm_calls", 0) or 0)
76+
77+
q_tokens = estimate_tokens(str(question or ""), model)
78+
trace_tokens = estimate_tokens(str(trace or ""), model)
79+
ans_tokens = estimate_tokens(str(answer or ""), model)
80+
in_tokens = lm_calls * q_tokens + trace_tokens
81+
out_tokens = ans_tokens
82+
cost = estimate_cost_usd(in_tokens, out_tokens, model=model, provider=provider)
83+
return {
84+
"input_tokens": in_tokens,
85+
"output_tokens": out_tokens,
86+
"cost_usd": cost,
87+
}

micro_agent/runtime.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22
import json, os, re, time, uuid, datetime
3-
from typing import Any, Dict, List, Optional, TypedDict
3+
from typing import Any, Dict, List, Optional, TypedDict, NotRequired
44
import ast
55
try:
66
import json_repair
@@ -21,18 +21,24 @@ class TraceRecord(TypedDict):
2121
question: str
2222
steps: List[Step]
2323
answer: str
24+
usage: NotRequired[Dict[str, Any]]
25+
cost_usd: NotRequired[float]
2426

2527
def new_trace_id() -> str:
2628
return uuid.uuid4().hex
2729

28-
def dump_trace(trace_id: str, question: str, steps: List[Step], answer: str) -> str:
30+
def dump_trace(trace_id: str, question: str, steps: List[Step], answer: str, *, usage: Optional[Dict[str, Any]] = None, cost_usd: Optional[float] = None) -> str:
2931
rec: TraceRecord = {
3032
"id": trace_id,
3133
"ts": datetime.datetime.now().isoformat(timespec="seconds"),
3234
"question": question,
3335
"steps": steps,
3436
"answer": answer,
3537
}
38+
if usage is not None:
39+
rec["usage"] = usage
40+
if cost_usd is not None:
41+
rec["cost_usd"] = float(cost_usd)
3642
path = os.path.join(TRACES_DIR, f"{trace_id}.jsonl")
3743
with open(path, "a", encoding="utf-8") as f:
3844
f.write(json.dumps(rec, ensure_ascii=False) + "\n")

micro_agent/server.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from fastapi.middleware.cors import CORSMiddleware
44
import os, json
55
from pydantic import BaseModel
6+
from .costs import estimate_prediction_cost
67
from importlib.metadata import version as _pkg_version, PackageNotFoundError
78
from .config import configure_lm
89
from .agent import MicroAgent
@@ -27,6 +28,8 @@ class AskResponse(BaseModel):
2728
trace_id: str
2829
trace_path: str
2930
steps: list
31+
usage: dict | None = None
32+
cost_usd: float | None = None
3033

3134
configure_lm()
3235
_agent = MicroAgent()
@@ -50,8 +53,10 @@ def ask(req: AskRequest):
5053
agent = _agent if req.use_tool_calls is None and req.max_steps == _agent.max_steps else MicroAgent(max_steps=req.max_steps, use_tool_calls=req.use_tool_calls)
5154
pred = agent(req.question)
5255
trace_id = new_trace_id()
53-
path = dump_trace(trace_id, req.question, pred.trace, pred.answer)
54-
return AskResponse(answer=pred.answer, trace_id=trace_id, trace_path=path, steps=pred.trace)
56+
usage = getattr(pred, "usage", {}) or {}
57+
est = estimate_prediction_cost(req.question, pred.trace, pred.answer, usage)
58+
path = dump_trace(trace_id, req.question, pred.trace, pred.answer, usage=usage, cost_usd=est.get("cost_usd"))
59+
return AskResponse(answer=pred.answer, trace_id=trace_id, trace_path=path, steps=pred.trace, usage=usage, cost_usd=est.get("cost_usd"))
5560

5661
@app.get("/trace/{trace_id}")
5762
def get_trace(trace_id: str):

0 commit comments

Comments
 (0)