Skip to content

Commit 2c0582f

Browse files
committed
Tighten tool validation and math normalization
1 parent f343fb1 commit 2c0582f

File tree

5 files changed

+99
-24
lines changed

5 files changed

+99
-24
lines changed

micro_agent/agent.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -154,9 +154,19 @@ def forward(self, question: str):
154154
total_in_tokens = 0
155155
total_out_tokens = 0
156156

157+
def _normalize_text(q: str) -> str:
158+
return (
159+
q.replace("\u00d7", "x")
160+
.replace("\u00f7", "/")
161+
.replace("\u2212", "-")
162+
.replace("\u2013", "-")
163+
.replace("\u2014", "-")
164+
)
165+
157166
def needs_math(q: str) -> bool:
158-
ql = q.lower()
159-
if re.search(r"[0-9].*[+\-*/%]", q):
167+
qn = _normalize_text(q)
168+
ql = qn.lower()
169+
if re.search(r"[0-9].*[+\-*/%]", qn):
160170
return True
161171
if re.search(r"\b\d+(?:\.\d+)?\s*(?:x|times|multiplied by)\s*\d+(?:\.\d+)?\b", ql):
162172
return True
@@ -170,7 +180,7 @@ def needs_math(q: str) -> bool:
170180
return False
171181

172182
def needs_time(q: str) -> bool:
173-
ql = q.lower()
183+
ql = _normalize_text(q).lower()
174184
if "current time" in ql or "current date" in ql:
175185
return True
176186
return re.search(r"\b(time|date|utc|now|today|tomorrow|yesterday|timestamp|datetime)\b", ql) is not None
@@ -222,7 +232,8 @@ def _accumulate_usage(input_text: str = "", output_text: str = ""):
222232
pass
223233

224234
def _infer_expression(q: str) -> str:
225-
ql = q.lower()
235+
qn = _normalize_text(q)
236+
ql = qn.lower()
226237
# Handle "divide X by Y" and "subtract X from Y"
227238
m = re.search(r"\bdivide\s+(\d+(?:\.\d+)?)\s+by\s+(\d+(?:\.\d+)?)\b", ql)
228239
if m:
@@ -245,11 +256,11 @@ def _infer_expression(q: str) -> str:
245256
return f"{m.group(1)}/{m.group(2)}"
246257
# Multi-number add/sum
247258
if "add" in ql or "sum" in ql:
248-
nums = [n for n in re.findall(r"\b\d+\b", q)]
259+
nums = [n for n in re.findall(r"\b\d+\b", qn)]
249260
if len(nums) >= 2:
250261
return "+".join(nums)
251262
# Fallback: longest math-like substring
252-
candidates = re.findall(r"[0-9\+\-\*/%\(\)\.!\^\s]+", q)
263+
candidates = re.findall(r"[0-9\+\-\*/%\(\)\.!\^\s]+", qn)
253264
candidates = [c.strip() for c in candidates if any(op in c for op in ["+","-","*","/","%","^","(",")","!"])]
254265
return max(candidates, key=len) if candidates else ""
255266

@@ -307,6 +318,14 @@ def _infer_expression(q: str) -> str:
307318
continue
308319
# Validate/execute; on validation error, record and continue planning
309320
obs = run_tool(name, args)
321+
if isinstance(obs, dict) and "error" in obs and obs.get("error", "").startswith("Unknown tool"):
322+
had_validation_error = True
323+
state.append({
324+
"tool": "⛔️validation_error",
325+
"args": {"name": name, "args": args},
326+
"observation": obs,
327+
})
328+
continue
310329
if isinstance(obs, dict) and "error" in obs and "validation" in obs.get("error", ""):
311330
had_validation_error = True
312331
state.append({
@@ -506,6 +525,13 @@ def _infer_expression(q: str) -> str:
506525
name = str(tool_desc)
507526
args = decision.get("args", {}) or {}
508527
obs = run_tool(name, args)
528+
if isinstance(obs, dict) and "error" in obs and obs.get("error", "").startswith("Unknown tool"):
529+
state.append({
530+
"tool": "⛔️validation_error",
531+
"args": {"name": name, "args": args},
532+
"observation": obs,
533+
})
534+
continue
509535
if isinstance(obs, dict) and "error" in obs and "validation" in obs.get("error", ""):
510536
# second-chance: record detailed schema hint in state and continue planning
511537
schema = TOOLS.get(name).schema if name in TOOLS else {}

micro_agent/config.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,15 @@ def __call__(self, *, prompt: str, **kwargs):
3333
import re, json as _json
3434
qmatch = re.search(r"Question:\s*(.*)", prompt, re.S)
3535
question = qmatch.group(1).strip() if qmatch else prompt
36-
ql = question.lower()
36+
qn = (question
37+
.replace("\u00d7", "x")
38+
.replace("\u00f7", "/")
39+
.replace("\u2212", "-")
40+
.replace("\u2013", "-")
41+
.replace("\u2014", "-"))
42+
ql = qn.lower()
3743
# heuristic: suggest calculator/now/final
38-
if (re.search(r"[0-9].*[+\-*/%]", question) or
44+
if (re.search(r"[0-9].*[+\-*/%]", qn) or
3945
re.search(r"\b\d+(?:\.\d+)?\s*(?:x|times|multiplied by|plus|minus|add|added to|subtract|subtracted by|divide|divided by|over)\s*\d+(?:\.\d+)?\b", ql) or
4046
(re.search(r"\d", ql) and any(w in ql for w in [
4147
"add","sum","plus","minus","subtract","multiply","divide","total","power","factorial","compute","calculate","!","**","^"
@@ -58,7 +64,7 @@ def __call__(self, *, prompt: str, **kwargs):
5864
expr = f"{m.group(1)}/{m.group(2)}"
5965
# crude expression extraction fallback
6066
if expr is None:
61-
cands = re.findall(r"[0-9\+\-\*/%\(\)\.!\^\s]+", question)
67+
cands = re.findall(r"[0-9\+\-\*/%\(\)\.!\^\s]+", qn)
6268
cands = [c.strip() for c in cands if c.strip()]
6369
expr = max(cands, key=len) if cands else "2+2"
6470
return _json.dumps({"tool": {"name": "calculator", "args": {"expression": expr}}})

micro_agent/costs.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -67,19 +67,29 @@ def estimate_cost_usd(input_tokens: int, output_tokens: int, model: str, provide
6767
def estimate_prediction_cost(question: str, trace: Any, answer: str, usage: Dict[str, Any]) -> Dict[str, Any]:
6868
"""Estimate token usage and USD cost for a single prediction.
6969
70-
Heuristic: input tokens ~= lm_calls * tokens(question) + tokens(str(trace))
71-
output tokens ~= tokens(answer)
70+
If usage provides token counts, prefer them. Otherwise fall back to a heuristic:
71+
input tokens ~= lm_calls * tokens(question) + tokens(str(trace))
72+
output tokens ~= tokens(answer)
7273
"""
73-
provider = (usage or {}).get("provider") or "openai"
74-
model = (usage or {}).get("model") or "gpt-4o-mini"
75-
lm_calls = int((usage or {}).get("lm_calls", 0) or 0)
74+
usage = usage or {}
75+
provider = usage.get("provider") or "openai"
76+
model = usage.get("model") or "gpt-4o-mini"
77+
lm_calls = int(usage.get("lm_calls", 0) or 0)
7678

77-
q_tokens = estimate_tokens(str(question or ""), model)
78-
trace_tokens = estimate_tokens(str(trace or ""), model)
79-
ans_tokens = estimate_tokens(str(answer or ""), model)
80-
in_tokens = lm_calls * q_tokens + trace_tokens
81-
out_tokens = ans_tokens
82-
cost = estimate_cost_usd(in_tokens, out_tokens, model=model, provider=provider)
79+
in_tokens = int(usage.get("input_tokens", 0) or 0)
80+
out_tokens = int(usage.get("output_tokens", 0) or 0)
81+
if in_tokens == 0 and out_tokens == 0:
82+
q_tokens = estimate_tokens(str(question or ""), model)
83+
trace_tokens = estimate_tokens(str(trace or ""), model)
84+
ans_tokens = estimate_tokens(str(answer or ""), model)
85+
in_tokens = lm_calls * q_tokens + trace_tokens
86+
out_tokens = ans_tokens
87+
88+
cost = usage.get("cost")
89+
if cost is None or cost == 0:
90+
cost = estimate_cost_usd(in_tokens, out_tokens, model=model, provider=provider)
91+
else:
92+
cost = float(cost)
8393
return {
8494
"input_tokens": in_tokens,
8595
"output_tokens": out_tokens,

micro_agent/tools.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,12 @@ def spec(self) -> Dict[str, Any]:
2929
def _eval_expr(node):
3030
# Python 3.10+: numeric literals appear as ast.Constant
3131
if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)) and not isinstance(node.value, bool):
32-
return node.value
32+
v = node.value
33+
if isinstance(v, float) and not math.isfinite(v):
34+
raise ValueError("number not finite")
35+
if abs(v) > MAX_ABS_NUMBER:
36+
raise ValueError("number too large")
37+
return v
3338
if isinstance(node, ast.BinOp) and type(node.op) in ALLOWED_OPS:
3439
lv, rv = _eval_expr(node.left), _eval_expr(node.right)
3540
if isinstance(lv, (int, float)) and abs(lv) > MAX_ABS_NUMBER: raise ValueError("number too large")
@@ -80,6 +85,15 @@ def preprocess_math(expr: str) -> str:
8085
# Replace simple factorial forms like 9! or 12! with fact(9) / fact(12)
8186
expr = str(expr or "").strip()
8287
expr = re.sub(r"(\d+)\!", r"fact(\1)", expr)
88+
# Normalize common unicode operators
89+
expr = (
90+
expr
91+
.replace("\u00d7", "*") # ×
92+
.replace("\u00f7", "/") # ÷
93+
.replace("\u2212", "-") # −
94+
.replace("\u2013", "-") # –
95+
.replace("\u2014", "-") # —
96+
)
8397
# Replace caret ^ with exponentiation
8498
expr = expr.replace("^", "**")
8599
# Trim trailing punctuation that commonly slips from prose
@@ -103,7 +117,9 @@ def tool_calculator(args: Dict[str, Any]):
103117

104118
def tool_now(args: Dict[str, Any]):
105119
tz = str(args.get("timezone", "local")).lower()
106-
now = datetime.datetime.now(datetime.timezone.utc) if tz == "utc" else datetime.datetime.now()
120+
if tz not in {"utc", "local"}:
121+
raise ValueError("timezone must be 'utc' or 'local'")
122+
now = datetime.datetime.now(datetime.timezone.utc) if tz == "utc" else datetime.datetime.now().astimezone()
107123
return {"iso": now.isoformat(timespec="seconds")}
108124

109125
def _load_plugins():
@@ -130,13 +146,13 @@ def _load_plugins():
130146
"calculator": Tool(
131147
"calculator",
132148
"Evaluate arithmetic expressions. Schema: {expression: string}. Supports +,-,*,/,**,%, //, parentheses.",
133-
{"type": "object", "properties": {"expression": {"type": "string"}}, "required": ["expression"]},
149+
{"type": "object", "properties": {"expression": {"type": "string"}}, "required": ["expression"], "additionalProperties": False},
134150
tool_calculator
135151
),
136152
"now": Tool(
137153
"now",
138154
"Return the current timestamp. Optional: {timezone: 'utc'|'local'}",
139-
{"type": "object", "properties": {"timezone": {"type": "string"}}, "required": []},
155+
{"type": "object", "properties": {"timezone": {"type": "string", "enum": ["utc", "local"]}}, "required": [], "additionalProperties": False},
140156
tool_now
141157
),
142158
}

tests/test_regressions.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import datetime
22
import json
3+
import re
34
import pytest
45

56
from micro_agent.config import configure_lm
67
from micro_agent.agent import MicroAgent
78
from micro_agent.tools import safe_eval_math
89
from micro_agent import runtime
10+
from micro_agent.tools import run_tool
911

1012

1113
def test_no_false_time_trigger_on_update(monkeypatch):
@@ -49,3 +51,18 @@ def test_dump_trace_serializes_non_json(tmp_path, monkeypatch):
4951
def test_result_magnitude_limit():
5052
with pytest.raises(ValueError):
5153
safe_eval_math("1000000*10000000")
54+
55+
56+
def test_unicode_multiply():
57+
assert safe_eval_math("3\u00d74") == 12
58+
59+
60+
def test_now_local_has_offset():
61+
obs = run_tool("now", {"timezone": "local"})
62+
assert "iso" in obs
63+
assert re.search(r"[+-]\d\d:\d\d$", obs["iso"])
64+
65+
66+
def test_now_invalid_timezone_validation():
67+
obs = run_tool("now", {"timezone": "pst"})
68+
assert "error" in obs and "validation" in obs["error"]

0 commit comments

Comments
 (0)