Skip to content

Commit f343fb1

Browse files
committed
Harden math/time detection and tracing
1 parent 023d220 commit f343fb1

File tree

5 files changed

+165
-58
lines changed

5 files changed

+165
-58
lines changed

micro_agent/agent.py

Lines changed: 88 additions & 46 deletions
Original file line number | Diff line number | Diff line change
@@ -156,17 +156,24 @@ def forward(self, question: str):
156156

157157
def needs_math(q: str) -> bool:
    """Heuristically decide whether question ``q`` calls for the calculator.

    Returns True when any of the following holds:

    * a digit is followed, anywhere later in the text, by one of
      ``+ - * / %`` (deliberately loose symbolic check);
    * a number or closing parenthesis is followed by ``^`` and another
      operand (power notation such as ``2^10`` or ``(-1)^(0.5)``);
    * two numbers are joined by a worded operator ("3 times 4",
      "10 divided by 2", "7 plus 8");
    * the text contains at least one digit and a math keyword.

    Args:
        q: The raw user question.

    Returns:
        True if the question looks like it needs arithmetic, else False.
    """
    ql = q.lower()
    number = r"\d+(?:\.\d+)?"

    # Symbolic arithmetic written with operators.
    if re.search(r"[0-9].*[+\-*/%]", q):
        return True

    # Power notation. "^" is not in the class above, so without this check
    # "2^10" would be missed even though the expression fallback and the
    # calculator both understand it. Anchored to an operand on each side so
    # a stray caret in prose does not trigger it.
    if re.search(r"[\d)]\s*\^\s*[\d(.\-]", q):
        return True

    # "<number> <worded operator> <number>".
    worded_ops = (
        "x|times|multiplied by|plus|minus|add|added to|subtract|"
        "subtracted by|divide|divided by|over"
    )
    if re.search(rf"\b{number}\s*(?:{worded_ops})\s*{number}\b", ql):
        return True

    # A math keyword only counts when the question also contains a digit;
    # this keeps sentences like "add some salt" from forcing the calculator.
    keywords = (
        "add", "sum", "plus", "minus", "subtract", "multiply", "divide",
        "total", "power", "factorial", "compute", "calculate",
    )
    if re.search(r"\d", ql) and any(w in ql for w in keywords):
        return True
    return False
164171

165172
def needs_time(q: str) -> bool:
    """Report whether question ``q`` appears to ask about the current time or date.

    The check is case-insensitive: either one of the phrases "current time" /
    "current date" occurs, or one of the time-related words appears as a
    whole word.
    """
    lowered = q.lower()
    if any(phrase in lowered for phrase in ("current time", "current date")):
        return True
    time_word = re.search(
        r"\b(time|date|utc|now|today|tomorrow|yesterday|timestamp|datetime)\b",
        lowered,
    )
    return time_word is not None
170177

171178
def used_tool(state, name: str) -> bool:
    """Return True if any step in ``state`` has its "tool" entry equal to ``name``."""
    for step in state:
        if step.get("tool") == name:
            return True
    return False
@@ -178,17 +185,25 @@ def used_tool(state, name: str) -> bool:
178185

179186
def _accumulate_usage(input_text: str = "", output_text: str = ""):
180187
# Pull new usage entries from dspy.settings.trace
188+
in_tok = 0
189+
out_tok = 0
190+
cost = 0.0
181191
try:
182192
for _, _, out in dspy.settings.trace[-1:]:
183193
usage = getattr(out, "usage", None) or {}
184194
nonlocal total_cost, total_in_tokens, total_out_tokens
185195
c = getattr(out, "cost", None)
186196
if c is not None:
187-
total_cost += float(c or 0)
188-
total_in_tokens += int(usage.get("input_tokens", 0) or 0)
189-
total_out_tokens += int(usage.get("output_tokens", 0) or 0)
197+
cost += float(c or 0)
198+
in_tok += int(usage.get("input_tokens", 0) or 0)
199+
out_tok += int(usage.get("output_tokens", 0) or 0)
190200
except Exception:
191201
pass
202+
if in_tok or out_tok or cost:
203+
total_cost += cost
204+
total_in_tokens += in_tok
205+
total_out_tokens += out_tok
206+
return
192207
# Heuristic fallback: estimate tokens from input/output texts and compute cost via env prices
193208
try:
194209
if input_text:
@@ -206,6 +221,38 @@ def _accumulate_usage(input_text: str = "", output_text: str = ""):
206221
except Exception:
207222
pass
208223

224+
def _infer_expression(q: str) -> str:
225+
ql = q.lower()
226+
# Handle "divide X by Y" and "subtract X from Y"
227+
m = re.search(r"\bdivide\s+(\d+(?:\.\d+)?)\s+by\s+(\d+(?:\.\d+)?)\b", ql)
228+
if m:
229+
return f"{m.group(1)}/{m.group(2)}"
230+
m = re.search(r"\bsubtract\s+(\d+(?:\.\d+)?)\s+from\s+(\d+(?:\.\d+)?)\b", ql)
231+
if m:
232+
return f"{m.group(2)}-{m.group(1)}"
233+
# Binary worded ops
234+
m = re.search(r"\b(\d+(?:\.\d+)?)\s*(?:x|times|multiplied by)\s*(\d+(?:\.\d+)?)\b", ql)
235+
if m:
236+
return f"{m.group(1)}*{m.group(2)}"
237+
m = re.search(r"\b(\d+(?:\.\d+)?)\s*(?:plus|add|added to)\s*(\d+(?:\.\d+)?)\b", ql)
238+
if m:
239+
return f"{m.group(1)}+{m.group(2)}"
240+
m = re.search(r"\b(\d+(?:\.\d+)?)\s*(?:minus|subtract|subtracted by)\s*(\d+(?:\.\d+)?)\b", ql)
241+
if m:
242+
return f"{m.group(1)}-{m.group(2)}"
243+
m = re.search(r"\b(\d+(?:\.\d+)?)\s*(?:divide|divided by|over)\s*(\d+(?:\.\d+)?)\b", ql)
244+
if m:
245+
return f"{m.group(1)}/{m.group(2)}"
246+
# Multi-number add/sum
247+
if "add" in ql or "sum" in ql:
248+
nums = [n for n in re.findall(r"\b\d+\b", q)]
249+
if len(nums) >= 2:
250+
return "+".join(nums)
251+
# Fallback: longest math-like substring
252+
candidates = re.findall(r"[0-9\+\-\*/%\(\)\.!\^\s]+", q)
253+
candidates = [c.strip() for c in candidates if any(op in c for op in ["+","-","*","/","%","^","(",")","!"])]
254+
return max(candidates, key=len) if candidates else ""
255+
209256
# Path A: OpenAI-native tool calling using DSPy signatures/adapters.
210257
if self._use_tool_calls:
211258
dspy_tools = to_dspy_tools()
@@ -275,6 +322,13 @@ def _accumulate_usage(input_text: str = "", output_text: str = ""):
275322
# Check finalization.
276323
final = getattr(pred, 'final', None)
277324
if final:
325+
if executed_any:
326+
state.append({
327+
"tool": "⛔️policy_violation",
328+
"args": {"reason": "tool_and_final"},
329+
"observation": "Model returned a final answer alongside a tool call.",
330+
})
331+
continue
278332
if had_policy_violation or had_validation_error:
279333
continue
280334
if must_math and not used_tool(state, "calculator"):
@@ -319,29 +373,15 @@ def _accumulate_usage(input_text: str = "", output_text: str = ""):
319373
if calculators:
320374
parts.append(str(calculators[0]["observation"].get("result")))
321375
elif must_math:
322-
# Last-chance math: infer a simple expression from the question.
323-
import re as _re
324-
ql = question.lower()
325-
if "add" in ql or "sum" in ql:
326-
nums = [int(n) for n in _re.findall(r"\b\d+\b", question)]
327-
if len(nums) >= 2:
328-
res = sum(nums)
376+
expr = _infer_expression(question)
377+
if expr:
378+
try:
379+
res = safe_eval_math(expr)
329380
parts.append(str(res))
330-
# also record as a calculator step for trace parity
331-
state.append({"tool": "calculator", "args": {"expression": "+".join(map(str, nums))}, "observation": {"result": res}})
381+
state.append({"tool": "calculator", "args": {"expression": expr}, "observation": {"result": res}})
332382
tool_calls += 1
333-
if not parts:
334-
candidates = _re.findall(r"[0-9\+\-\*/%\(\)\.!\^\s]+", question)
335-
candidates = [c.strip() for c in candidates if any(op in c for op in ["+","-","*","/","%","^","(",")","!"])]
336-
expr = max(candidates, key=len) if candidates else ""
337-
if expr:
338-
try:
339-
res = safe_eval_math(expr)
340-
parts.append(str(res))
341-
state.append({"tool": "calculator", "args": {"expression": expr}, "observation": {"result": res}})
342-
tool_calls += 1
343-
except Exception:
344-
pass
383+
except Exception:
384+
pass
345385
if nows:
346386
iso = nows[-1]["observation"].get("iso")
347387
if iso:
@@ -413,6 +453,13 @@ def _accumulate_usage(input_text: str = "", output_text: str = ""):
413453
continue
414454

415455
if "final" in decision:
456+
if "tool" in decision:
457+
state.append({
458+
"tool": "⛔️policy_violation",
459+
"args": {"reason": "tool_and_final"},
460+
"observation": "Decision contained both tool and final.",
461+
})
462+
continue
416463
# Enforce tool usage policy: if required tools not yet used, keep planning.
417464
if must_math and not used_tool(state, "calculator"):
418465
state.append({"tool": "⛔️policy_violation", "args": {}, "observation": "Finalize attempted before calculator."})
@@ -430,7 +477,14 @@ def _accumulate_usage(input_text: str = "", output_text: str = ""):
430477
iso = nows[-1]["observation"].get("iso")
431478
if iso:
432479
composed_parts.append(f"UTC: {iso}")
433-
final_text = " | ".join(composed_parts) if composed_parts else decision["final"].get("answer", "")
480+
if composed_parts:
481+
final_text = " | ".join(composed_parts)
482+
else:
483+
final_payload = decision.get("final")
484+
if isinstance(final_payload, dict):
485+
final_text = final_payload.get("answer", "")
486+
else:
487+
final_text = str(final_payload) if final_payload is not None else ""
434488
p = dspy.Prediction(answer=final_text, trace=state)
435489
p.usage = {
436490
"lm_calls": lm_calls,
@@ -478,24 +532,12 @@ def _accumulate_usage(input_text: str = "", output_text: str = ""):
478532
if calc_results:
479533
parts.append(str(calc_results[0]))
480534
if must_math and not parts:
481-
# Last-chance math: try to infer a simple expression from the question.
482-
ql = question.lower()
483-
# If looks like 'add X and Y', sum integers.
484-
import re
485-
if "add" in ql or "sum" in ql:
486-
nums = [int(n) for n in re.findall(r"\b\d+\b", question)]
487-
if len(nums) >= 2:
488-
parts.append(str(sum(nums)))
489-
if not parts:
490-
# Extract longest math-like substring and evaluate.
491-
candidates = re.findall(r"[0-9\+\-\*/%\(\)\.!\^\s]+", question)
492-
candidates = [c.strip() for c in candidates if any(op in c for op in ["+","-","*","/","%","^","(",")","!"])]
493-
expr = max(candidates, key=len) if candidates else ""
494-
if expr:
495-
try:
496-
parts.append(str(safe_eval_math(expr)))
497-
except Exception:
498-
pass
535+
expr = _infer_expression(question)
536+
if expr:
537+
try:
538+
parts.append(str(safe_eval_math(expr)))
539+
except Exception:
540+
pass
499541
if nows:
500542
iso = nows[-1]["observation"].get("iso")
501543
if iso:

micro_agent/config.py

Lines changed: 27 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -35,16 +35,35 @@ def __call__(self, *, prompt: str, **kwargs):
3535
question = qmatch.group(1).strip() if qmatch else prompt
3636
ql = question.lower()
3737
# heuristic: suggest calculator/now/final
38-
if re.search(r"[0-9].*[+\-*/]", question) or any(w in ql for w in [
39-
"add","sum","multiply","divide","compute","calculate","total","power","factorial","!","**","^"
40-
]):
41-
# crude expression extraction
42-
cands = re.findall(r"[0-9\+\-\*/%\(\)\.!\^\s]+", question)
43-
cands = [c.strip() for c in cands if c.strip()]
44-
expr = max(cands, key=len) if cands else "2+2"
38+
if (re.search(r"[0-9].*[+\-*/%]", question) or
39+
re.search(r"\b\d+(?:\.\d+)?\s*(?:x|times|multiplied by|plus|minus|add|added to|subtract|subtracted by|divide|divided by|over)\s*\d+(?:\.\d+)?\b", ql) or
40+
(re.search(r"\d", ql) and any(w in ql for w in [
41+
"add","sum","plus","minus","subtract","multiply","divide","total","power","factorial","compute","calculate","!","**","^"
42+
]))):
43+
expr = None
44+
m = re.search(r"\b(\d+(?:\.\d+)?)\s*(?:x|times|multiplied by)\s*(\d+(?:\.\d+)?)\b", ql)
45+
if m:
46+
expr = f"{m.group(1)}*{m.group(2)}"
47+
if expr is None:
48+
m = re.search(r"\b(\d+(?:\.\d+)?)\s*(?:plus|add|added to)\s*(\d+(?:\.\d+)?)\b", ql)
49+
if m:
50+
expr = f"{m.group(1)}+{m.group(2)}"
51+
if expr is None:
52+
m = re.search(r"\b(\d+(?:\.\d+)?)\s*(?:minus|subtract|subtracted by)\s*(\d+(?:\.\d+)?)\b", ql)
53+
if m:
54+
expr = f"{m.group(1)}-{m.group(2)}"
55+
if expr is None:
56+
m = re.search(r"\b(\d+(?:\.\d+)?)\s*(?:divide|divided by|over)\s*(\d+(?:\.\d+)?)\b", ql)
57+
if m:
58+
expr = f"{m.group(1)}/{m.group(2)}"
59+
# crude expression extraction fallback
60+
if expr is None:
61+
cands = re.findall(r"[0-9\+\-\*/%\(\)\.!\^\s]+", question)
62+
cands = [c.strip() for c in cands if c.strip()]
63+
expr = max(cands, key=len) if cands else "2+2"
4564
return _json.dumps({"tool": {"name": "calculator", "args": {"expression": expr}}})
4665
if ("current time" in ql or "current date" in ql or
47-
re.search(r"\b(time|times|date|dates|utc|now|today|tomorrow|yesterday|timestamp|datetime)\b", ql)):
66+
re.search(r"\b(time|date|utc|now|today|tomorrow|yesterday|timestamp|datetime)\b", ql)):
4867
return _json.dumps({"tool": {"name": "now", "args": {"timezone": "utc"}}})
4968
return _json.dumps({"final": {"answer": "ok"}})
5069
dspy.settings.configure(lm=_MockLM(), track_usage=True)

micro_agent/runtime.py

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -39,9 +39,10 @@ def dump_trace(trace_id: str, question: str, steps: List[Step], answer: str, *,
3939
rec["usage"] = usage
4040
if cost_usd is not None:
4141
rec["cost_usd"] = float(cost_usd)
42+
os.makedirs(TRACES_DIR, exist_ok=True)
4243
path = os.path.join(TRACES_DIR, f"{trace_id}.jsonl")
4344
with open(path, "a", encoding="utf-8") as f:
44-
f.write(json.dumps(rec, ensure_ascii=False) + "\n")
45+
f.write(json.dumps(rec, ensure_ascii=False, default=str) + "\n")
4546
return path
4647

4748
def extract_json_block(text: str) -> str:
@@ -90,7 +91,10 @@ def parse_decision_text(text: str) -> Dict[str, Any]:
9091
block = extract_json_block(text)
9192
# 1) strict json
9293
try:
93-
return json.loads(block)
94+
obj = json.loads(block)
95+
if isinstance(obj, dict):
96+
return obj
97+
raise ValueError("Decision JSON is not an object")
9498
except Exception:
9599
pass
96100
# 2) json-repair (if available)
@@ -99,7 +103,10 @@ def parse_decision_text(text: str) -> Dict[str, Any]:
99103
repaired = json_repair.repair(block)
100104
if isinstance(repaired, dict):
101105
return repaired
102-
return json.loads(repaired)
106+
obj = json.loads(repaired)
107+
if isinstance(obj, dict):
108+
return obj
109+
raise ValueError("Decision JSON is not an object")
103110
except Exception:
104111
pass
105112
# 3) python literal (handles single quotes)

micro_agent/tools.py

Lines changed: 12 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -40,13 +40,21 @@ def _eval_expr(node):
4040
result = ALLOWED_OPS[type(node.op)](lv, rv)
4141
if isinstance(result, complex):
4242
raise ValueError("complex results are not supported")
43+
if isinstance(result, float) and not math.isfinite(result):
44+
raise ValueError("number not finite")
45+
if isinstance(result, (int, float)) and abs(result) > MAX_ABS_NUMBER:
46+
raise ValueError("number too large")
4347
return result
4448
if isinstance(node, ast.UnaryOp) and type(node.op) in ALLOWED_OPS:
4549
v = _eval_expr(node.operand)
4650
if isinstance(v, (int, float)) and abs(v) > MAX_ABS_NUMBER: raise ValueError("number too large")
4751
result = ALLOWED_OPS[type(node.op)](v)
4852
if isinstance(result, complex):
4953
raise ValueError("complex results are not supported")
54+
if isinstance(result, float) and not math.isfinite(result):
55+
raise ValueError("number not finite")
56+
if isinstance(result, (int, float)) and abs(result) > MAX_ABS_NUMBER:
57+
raise ValueError("number too large")
5058
return result
5159
if isinstance(node, ast.Call) and isinstance(node.func, ast.Name) and node.func.id in ALLOWED_CALLS:
5260
if len(node.args) != 1:
@@ -61,7 +69,10 @@ def _eval_expr(node):
6169
raise ValueError("factorial requires a non-negative integer")
6270
if arg_int > MAX_FACTORIAL_N:
6371
raise ValueError("factorial too large")
64-
return ALLOWED_CALLS[node.func.id](arg_int)
72+
result = ALLOWED_CALLS[node.func.id](arg_int)
73+
if isinstance(result, (int, float)) and abs(result) > MAX_ABS_NUMBER:
74+
raise ValueError("number too large")
75+
return result
6576
if isinstance(node, ast.Expression): return _eval_expr(node.body)
6677
raise ValueError("Disallowed expression")
6778

tests/test_regressions.py

Lines changed: 28 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,11 @@
1+
import datetime
2+
import json
13
import pytest
24

35
from micro_agent.config import configure_lm
46
from micro_agent.agent import MicroAgent
57
from micro_agent.tools import safe_eval_math
8+
from micro_agent import runtime
69

710

811
def test_no_false_time_trigger_on_update(monkeypatch):
@@ -21,3 +24,28 @@ def test_factorial_rejects_non_integer():
2124
def test_complex_results_rejected():
2225
with pytest.raises(ValueError):
2326
safe_eval_math("(-1)^(0.5)")
27+
28+
29+
def test_times_is_math_not_time(monkeypatch):
    """"3 times 4" must be answered via the calculator and never call the `now` tool."""
    monkeypatch.setenv("LLM_PROVIDER", "mock")
    configure_lm()
    prediction = MicroAgent(max_steps=3)("What is 3 times 4?")
    assert "12" in prediction.answer
    tools_used = [step.get("tool") for step in (prediction.trace or [])]
    assert "calculator" in tools_used
    assert "now" not in tools_used
37+
38+
39+
def test_dump_trace_serializes_non_json(tmp_path, monkeypatch):
    """dump_trace must write a step holding a datetime, which comes back as a string."""
    # Redirect trace output into pytest's temporary directory.
    monkeypatch.setattr(runtime, "TRACES_DIR", str(tmp_path))
    trace_id = runtime.new_trace_id()
    observation = {"when": datetime.datetime(2020, 1, 1)}
    steps = [{"tool": "now", "args": {}, "observation": observation}]
    path = runtime.dump_trace(trace_id, "q", steps, "a")
    with open(path, "r", encoding="utf-8") as f:
        first_line = f.readline()
    rec = json.loads(first_line)
    assert rec["steps"][0]["observation"]["when"].startswith("2020-01-01")
47+
48+
49+
def test_result_magnitude_limit():
    """safe_eval_math must raise ValueError for the product 1000000*10000000."""
    oversized_product = "1000000*10000000"
    with pytest.raises(ValueError):
        safe_eval_math(oversized_product)

0 commit comments

Comments
 (0)