|
@@ -4,6 +4,7 @@
 from micro_agent.config import configure_lm
 from micro_agent.agent import MicroAgent
 from micro_agent.runtime import new_trace_id, dump_trace
+from micro_agent.costs import estimate_tokens, estimate_cost_usd
 
 def load_yaml(path: str):
     with open(path, "r", encoding="utf-8") as f:
@@ -70,10 +71,21 @@ def main():
         latencies.append(dt)
         # Basic usage tracking (provided by MicroAgent)
         usage = getattr(pred, "usage", {}) or {}
-        lm_calls_list.append(int(usage.get("lm_calls", 0) or 0))
+        lm_calls = int(usage.get("lm_calls", 0) or 0)
+        lm_calls_list.append(lm_calls)
         tool_calls_list.append(int(usage.get("tool_calls", 0) or 0))
         steps_list.append(len(pred.trace or []))
-        costs_list.append(float(usage.get("cost", 0.0) or 0.0))
+
+        # Approximate per-run cost in USD via simple token heuristics
+        provider = usage.get("provider") or "openai"
+        model = usage.get("model") or "gpt-4o-mini"
+        q_text = str(q)
+        trace_text = json.dumps(pred.trace, ensure_ascii=False)
+        ans_text = str(pred.answer or "")
+        # Rough input tokens ~ (lm_calls * question) + final trace
+        in_tokens = lm_calls * estimate_tokens(q_text, model=model) + estimate_tokens(trace_text, model=model)
+        out_tokens = estimate_tokens(ans_text, model=model)
+        costs_list.append(estimate_cost_usd(in_tokens, out_tokens, model=model, provider=provider))
 
         print(f"[{i}/{len(dataset)}] s={s} t={dt:.2f}s q={q!r}")
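The estimate_tokens and estimate_cost_usd helpers come from the new
micro_agent.costs module, whose body is not part of this diff. Below is a
minimal sketch of what such helpers could look like, assuming a roughly
4-characters-per-token heuristic and an illustrative per-1K-token price table;
only the function names and signatures are taken from the diff, everything
else here is an assumption:

# Hypothetical micro_agent/costs.py, not the actual module from this PR.
# Prices are illustrative USD rates per 1K tokens (input, output).
_PRICES = {
    ("openai", "gpt-4o-mini"): (0.00015, 0.0006),
}
_DEFAULT_PRICE = (0.001, 0.002)  # fallback for unknown provider/model pairs


def estimate_tokens(text: str, model: str = "gpt-4o-mini") -> int:
    """Crude estimate: ~4 characters per token for English-like text.

    A real implementation would use a proper tokenizer (e.g. tiktoken);
    `model` is kept only for signature compatibility with the caller above.
    """
    return (len(text) + 3) // 4 if text else 0


def estimate_cost_usd(in_tokens: int, out_tokens: int,
                      model: str = "gpt-4o-mini",
                      provider: str = "openai") -> float:
    """Convert token counts to an approximate USD cost via the price table."""
    in_price, out_price = _PRICES.get((provider, model), _DEFAULT_PRICE)
    return (in_tokens / 1000.0) * in_price + (out_tokens / 1000.0) * out_price

Two caveats worth noting. First, the input-token heuristic likely undercounts
multi-step runs, since each LM call typically resends the accumulated trace and
any system prompt, not just the question. Second, the new block calls
json.dumps, so the script needs import json at the top if it is not already
imported; that line falls outside the hunks shown here.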
79 | 91 |
|
|