
Commit 572460b

test(evals): add 12 more math/time tasks; smoke-tested with mock provider
Parent: d0ce814

4 files changed: +128 −9 lines

evals/tasks.yaml

Lines changed: 37 additions & 0 deletions
```diff
@@ -16,3 +16,40 @@
 
 - question: "What's 9! / (3!*3!*3!)? Just the integer."
   expect_contains: "1680"
+
+# Additional math/time tasks
+- question: "What's 2^10? Return only the number."
+  expect_contains: "1024"
+
+- question: "Compute 7/2 and return a decimal."
+  expect_contains: "3.5"
+
+- question: "Compute 7//2 (floor division). Return only the integer."
+  expect_contains: "3"
+
+- question: "What is 100 % 7? Return only the integer."
+  expect_contains: "2"
+
+- question: "Calculate 6! and return only the number."
+  expect_contains: "720"
+
+- question: "Compute (12.5 * 4). Return only the number."
+  expect_contains: "50"
+
+- question: "Compute ((7+3)*5 - 12)/4 and return a decimal."
+  expect_contains: "9.5"
+
+- question: "Compute (3+5)*2^3 and return only the number."
+  expect_contains: "64"
+
+- question: "What is (12 - 30)? Return only the number."
+  expect_contains: "-18"
+
+- question: "Add 100 and 250, then tell me the current date in UTC."
+  expect_contains: "350"
+
+- question: "Give me the current timestamp in UTC (ISO)."
+  expect_key: "iso"
+
+- question: "Tell me today's date (UTC) in ISO format."
+  expect_key: "iso"
```
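
As a quick cross-check, the new `expect_contains` values all agree with plain Python arithmetic (reading the tasks' `^` as exponentiation, which the `2^10` → `1024` expectation implies about `safe_eval_math`):

```python
# Sanity check of the expected answers above; pure Python, with ** standing
# in for the ^ used in the task prompts.
import math

assert 2 ** 10 == 1024
assert 7 / 2 == 3.5 and 7 // 2 == 3 and 100 % 7 == 2
assert math.factorial(6) == 720 and 12.5 * 4 == 50
assert ((7 + 3) * 5 - 12) / 4 == 9.5 and (3 + 5) * 2 ** 3 == 64
assert 12 - 30 == -18 and 100 + 250 == 350
```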

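The diff doesn't show the harness side, but a minimal checker sketch suggests how these fields might be scored; `passes` is hypothetical, and the reading of `expect_key` (a key present in a JSON answer) is an assumption:

```python
# Hypothetical scorer for the YAML tasks above -- not part of this commit.
# Assumes expect_contains is a substring match on the answer and expect_key
# looks for a key in a JSON payload (falling back to a substring match).
import json

def passes(task: dict, answer: str) -> bool:
    if "expect_contains" in task:
        return task["expect_contains"] in answer
    if "expect_key" in task:
        try:
            return task["expect_key"] in json.loads(answer)
        except (TypeError, ValueError):
            return task["expect_key"] in answer
    return False
```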
micro_agent/agent.py

Lines changed: 48 additions & 6 deletions
```diff
@@ -5,6 +5,7 @@
 from .signatures import PlanOrAct, Finalize, PlanWithTools
 from .tools import TOOLS, run_tool, safe_eval_math, to_dspy_tools
 from .runtime import parse_decision_text
+from .costs import estimate_tokens, estimate_cost_usd
 
 class MicroAgent(dspy.Module):
     """
@@ -151,7 +152,7 @@ def used_tool(state, name: str) -> bool:
 
         state: List[Dict[str, Any]] = []
 
-        def _accumulate_usage():
+        def _accumulate_usage(input_text: str = "", output_text: str = ""):
            # Pull new usage entries from dspy.settings.trace
             try:
                 for _, _, out in dspy.settings.trace[-1:]:
@@ -164,6 +165,22 @@ def _accumulate_usage():
                     total_out_tokens += int(usage.get("output_tokens", 0) or 0)
             except Exception:
                 pass
+            # Heuristic fallback: estimate tokens from input/output texts and compute cost via env prices
+            try:
+                if input_text:
+                    it = estimate_tokens(input_text, getattr(self.lm, "model", ""))
+                else:
+                    it = 0
+                if output_text:
+                    ot = estimate_tokens(output_text, getattr(self.lm, "model", ""))
+                else:
+                    ot = 0
+                if it or ot:
+                    total_in_tokens += it
+                    total_out_tokens += ot
+                    total_cost += estimate_cost_usd(it, ot, getattr(self.lm, "model", ""), self._provider or "")
+            except Exception:
+                pass
 
         # Path A: OpenAI-native tool calling using DSPy signatures/adapters.
         if self._use_tool_calls:
@@ -184,6 +201,17 @@ def _accumulate_usage():
                         total_out_tokens += int(usage.get('output_tokens', 0) or 0)
                 except Exception:
                     pass
+                # Heuristic fallback: estimate using a reconstructed prompt & result
+                try:
+                    approx_prompt = self._decision_prompt(
+                        question=question,
+                        state_json=json.dumps(state, ensure_ascii=False),
+                        tools_json=json.dumps(self._tool_list, ensure_ascii=False),
+                    )
+                    approx_out = getattr(pred, 'final', None) or (str(getattr(pred, 'tool_calls', '')))
+                    _accumulate_usage(approx_prompt, approx_out)
+                except Exception:
+                    pass
 
                 # If tool calls are proposed, execute them.
                 calls = getattr(pred, 'tool_calls', None)
@@ -302,6 +330,11 @@ def _accumulate_usage():
         # Path B: Ollama-friendly loop via raw LM completions and robust JSON parsing.
         for _ in range(self.max_steps):
             lm_calls += 1
+            prompt_text = self._decision_prompt(
+                question=question,
+                state_json=json.dumps(state, ensure_ascii=False),
+                tools_json=json.dumps(self._tool_list, ensure_ascii=False),
+            )
             raw = self.lm(
                 prompt=self._decision_prompt(
                     question=question,
@@ -310,13 +343,18 @@ def _accumulate_usage():
                 )
             )
             decision_text = raw[0] if isinstance(raw, list) else (raw if isinstance(raw, str) else str(raw))
-            _accumulate_usage()
+            _accumulate_usage(prompt_text, decision_text)
 
             # Extract and parse JSON; if malformed, try a flexible parser and one self-correction retry.
             try:
                 decision = parse_decision_text(decision_text)
             except Exception:
                 lm_calls += 1
+                prompt_text = self._decision_prompt(
+                    question=question,
+                    state_json=json.dumps(state, ensure_ascii=False),
+                    tools_json=json.dumps(self._tool_list, ensure_ascii=False),
+                )
                 raw = self.lm(
                     prompt=self._decision_prompt(
                         question=question,
@@ -325,7 +363,7 @@ def _accumulate_usage():
                     )
                 )
                 decision_text = raw[0] if isinstance(raw, list) else (raw if isinstance(raw, str) else str(raw))
-                _accumulate_usage()
+                _accumulate_usage(prompt_text, decision_text)
                 try:
                     decision = parse_decision_text(decision_text)
                 except Exception:
@@ -413,14 +451,18 @@ def _accumulate_usage():
             ans = " | ".join(parts) if parts else ""
         else:
             lm_calls += 1
+            finalize_prompt = (
+                "Given the question and the trace of tool observations, write the final answer.\n\n"
+                f"Question: {question}\n\nTrace: {json.dumps(state, ensure_ascii=False)}\n\n"
+                "Answer succinctly."
+            )
             raw = self.lm(
                 prompt=(
-                    "Given the question and the trace of tool observations, write the final answer.\n\n"
-                    f"Question: {question}\n\nTrace: {json.dumps(state, ensure_ascii=False)}\n\n"
-                    "Answer succinctly."
+                    finalize_prompt
                 )
             )
             ans = raw[0] if isinstance(raw, list) else (raw if isinstance(raw, str) else str(raw))
+            _accumulate_usage(finalize_prompt, ans)
         p = dspy.Prediction(answer=ans, trace=state)
         p.usage = {
             "lm_calls": lm_calls,
```

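A hypothetical call site for the instrumented agent; of the usage keys, only `lm_calls` is visible in this hunk, so the token and cost fields below are guesses based on the `total_in_tokens` / `total_out_tokens` / `total_cost` accumulators:

```python
# Hypothetical call site -- constructor arguments and usage keys other than
# "lm_calls" are assumptions, not confirmed by the diff.
from micro_agent.agent import MicroAgent

agent = MicroAgent()
pred = agent(question="Compute 7//2 (floor division). Return only the integer.")
print(pred.answer)             # should contain "3"
print(pred.usage["lm_calls"])  # confirmed by the diff
print(pred.usage)              # presumably also token counts and estimated cost
```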
micro_agent/config.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -31,7 +31,7 @@ def configure_lm():
     def _try(name, fn):
         try:
             lm = fn()
-            dspy.settings.configure(lm=lm)
+            dspy.settings.configure(lm=lm, track_usage=True)
             return True
         except Exception as e:
             tried.append((name, repr(e)))
@@ -79,9 +79,9 @@ def __call__(self, *, prompt: str, **kwargs):
 
     # Allow explicit mock via env
     if provider == "mock":
-        dspy.settings.configure(lm=_MockLM())
+        dspy.settings.configure(lm=_MockLM(), track_usage=True)
         return
 
     # If we got here, all backends failed: use mock and include details in a warning
-    dspy.settings.configure(lm=_MockLM())
+    dspy.settings.configure(lm=_MockLM(), track_usage=True)
     return
```

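For reference, DSPy's `track_usage` flag makes per-call token usage retrievable from predictions; a small sketch under that assumption (the model string is illustrative, not from this diff):

```python
# Sketch of what track_usage=True enables in DSPy (recent releases expose
# Prediction.get_lm_usage()); the model name here is just an example.
import dspy

dspy.settings.configure(lm=dspy.LM("openai/gpt-4o-mini"), track_usage=True)
pred = dspy.Predict("question -> answer")(question="What is 100 % 7?")
print(pred.get_lm_usage())  # per-model prompt/completion token counts
```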
micro_agent/costs.py

Lines changed: 40 additions & 0 deletions
```diff
@@ -0,0 +1,40 @@
+from __future__ import annotations
+import os
+from typing import Tuple
+
+def _try_tiktoken(model: str):
+    try:
+        import tiktoken
+        # Use a generic encoding if specific not found
+        try:
+            enc = tiktoken.encoding_for_model(model)
+        except Exception:
+            enc = tiktoken.get_encoding("o200k_base")
+        return enc
+    except Exception:
+        return None
+
+def estimate_tokens(text: str, model: str = "gpt-4o-mini") -> int:
+    if not text:
+        return 0
+    enc = _try_tiktoken(model)
+    if enc is None:
+        # Fallback heuristic: ~4 chars per token
+        return max(1, len(text) // 4)
+    try:
+        return len(enc.encode(text))
+    except Exception:
+        return max(1, len(text) // 4)
+
+def get_prices_per_1k(model: str, provider: str) -> Tuple[float, float]:
+    # Allow env overrides; default to 0 to avoid misleading values.
+    in_price = float(os.getenv("OPENAI_INPUT_PRICE_PER_1K", "0") or 0)
+    out_price = float(os.getenv("OPENAI_OUTPUT_PRICE_PER_1K", "0") or 0)
+    if provider != "openai":
+        return 0.0, 0.0
+    return in_price, out_price
+
+def estimate_cost_usd(input_tokens: int, output_tokens: int, model: str, provider: str) -> float:
+    in_price_1k, out_price_1k = get_prices_per_1k(model, provider)
+    return (input_tokens / 1000.0) * in_price_1k + (output_tokens / 1000.0) * out_price_1k
+
```

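To exercise the new module end to end, a short usage sketch; the env var names and function signatures come from the diff above, while the prices are made-up examples (the defaults are 0 precisely so that no misleading cost is reported):

```python
# Example use of the new cost helpers; the prices are illustrative, not
# real quotes -- get_prices_per_1k reads them from these env vars.
import os
from micro_agent.costs import estimate_tokens, estimate_cost_usd

os.environ["OPENAI_INPUT_PRICE_PER_1K"] = "0.00015"   # example value
os.environ["OPENAI_OUTPUT_PRICE_PER_1K"] = "0.0006"   # example value

it = estimate_tokens("What's 2^10? Return only the number.", "gpt-4o-mini")
ot = estimate_tokens("1024", "gpt-4o-mini")
print(it, ot, estimate_cost_usd(it, ot, "gpt-4o-mini", "openai"))
```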