
Commit 023d220

Fix tool-call policy, parsing, and config fallbacks
1 parent cfd2d7a commit 023d220

6 files changed: +146 -28 lines changed


micro_agent/agent.py

Lines changed: 51 additions & 10 deletions
@@ -19,11 +19,7 @@ def __init__(self, max_steps: int = 6, use_tool_calls: bool | None = None):
         self.finalize = None  # fallback finalize handled via LM prompt
         self._tool_list = [t.spec() for t in TOOLS.values()]
         self.max_steps = max_steps
-        self._provider = None
-        try:
-            self._provider = (self.lm.model.split("/", 1)[0] if getattr(self.lm, "model", None) else None)
-        except Exception:
-            self._provider = None
+        self._provider = self._infer_provider(self.lm)
         # Determine function-calls mode
         env_override = os.getenv("USE_TOOL_CALLS")
         if isinstance(use_tool_calls, bool):
@@ -37,10 +33,36 @@ def __init__(self, max_steps: int = 6, use_tool_calls: bool | None = None):
         try:
             from dspy.adapters import JSONAdapter
             dspy.settings.configure(adapter=JSONAdapter())
+            if to_dspy_tools():
+                self.planner = dspy.Predict(PlanWithTools)
+                self._load_compiled_demos()
+            else:
+                self._use_tool_calls = False
         except Exception:
-            pass
-        self.planner = dspy.Predict(PlanWithTools)
-        self._load_compiled_demos()
+            self._use_tool_calls = False
+
+    def _infer_provider(self, lm) -> str | None:
+        try:
+            prov = getattr(lm, "provider", None) or getattr(lm, "_provider", None)
+            if isinstance(prov, str) and prov.strip():
+                return prov.strip().lower()
+        except Exception:
+            pass
+        try:
+            cls_name = lm.__class__.__name__.lower()
+            if "openai" in cls_name:
+                return "openai"
+            if "ollama" in cls_name:
+                return "ollama"
+        except Exception:
+            pass
+        try:
+            model = getattr(lm, "model", None)
+            if isinstance(model, str) and "/" in model:
+                return model.split("/", 1)[0].lower()
+        except Exception:
+            pass
+        return None

     def _load_compiled_demos(self):
         import json as _json
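
For reference, the new _infer_provider prefers an explicit provider attribute, then falls back to the LM class name, then to a provider/model prefix in the model string. A minimal sketch of that precedence, assuming the package is importable and using throwaway SimpleNamespace objects as stand-ins for real LMs:

from types import SimpleNamespace

from micro_agent.agent import MicroAgent

# _infer_provider never touches self, so passing None works for illustration.
infer = MicroAgent._infer_provider
print(infer(None, SimpleNamespace(provider="OpenAI")))      # "openai"  (explicit attribute wins)
print(infer(None, SimpleNamespace(model="ollama/llama3")))  # "ollama"  (provider/model prefix)
print(infer(None, SimpleNamespace(model="gpt-4o-mini")))    # None      (nothing to infer)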
@@ -142,7 +164,9 @@ def needs_math(q: str) -> bool:

 def needs_time(q: str) -> bool:
     ql = q.lower()
-    return any(w in ql for w in ["time", "date", "utc", "current time", "now"])
+    if "current time" in ql or "current date" in ql:
+        return True
+    return re.search(r"\b(time|times|date|dates|utc|now|today|tomorrow|yesterday|timestamp|datetime)\b", ql) is not None

 def used_tool(state, name: str) -> bool:
     return any(step.get("tool") == name for step in state)
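
The word-boundary regex stops substrings such as "update" from triggering the now tool. A quick illustration, restating the pattern from this hunk:

import re

TIME_RE = re.compile(r"\b(time|times|date|dates|utc|now|today|tomorrow|yesterday|timestamp|datetime)\b")

print(bool(TIME_RE.search("please update the docs")))    # False -- "update" no longer matches "date"
print(bool(TIME_RE.search("what is the date today?")))   # True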
@@ -216,8 +240,19 @@ def _accumulate_usage(input_text: str = "", output_text: str = ""):
            # If tool calls are proposed, execute them.
            calls = getattr(pred, 'tool_calls', None)
            executed_any = False
+            had_validation_error = False
+            had_policy_violation = False
            if calls and getattr(calls, 'tool_calls', None):
-                for call in calls.tool_calls:
+                call_list = list(calls.tool_calls)
+                if len(call_list) > 1:
+                    had_policy_violation = True
+                    state.append({
+                        "tool": "⛔️policy_violation",
+                        "args": {"reason": "multiple_tool_calls", "count": len(call_list)},
+                        "observation": "Model returned multiple tool calls in one step; executing only the first.",
+                    })
+                    call_list = call_list[:1]
+                for call in call_list:
                    try:
                        name = getattr(call, 'name')
                        args = getattr(call, 'args') or {}
@@ -226,6 +261,7 @@ def _accumulate_usage(input_text: str = "", output_text: str = ""):
                        # Validate/execute; on validation error, record and continue planning
                        obs = run_tool(name, args)
                        if isinstance(obs, dict) and "error" in obs and "validation" in obs.get("error", ""):
+                            had_validation_error = True
                            state.append({
                                "tool": "⛔️validation_error",
                                "args": {"name": name, "args": args},
@@ -239,6 +275,8 @@ def _accumulate_usage(input_text: str = "", output_text: str = ""):
            # Check finalization.
            final = getattr(pred, 'final', None)
            if final:
+                if had_policy_violation or had_validation_error:
+                    continue
                if must_math and not used_tool(state, "calculator"):
                    state.append({"tool": "⛔️policy_violation", "args": {}, "observation": "Finalize before calculator (OpenAI path)."})
            # If tools were suggested and executed this step, iterate; else force tool suggestion by continuing.
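
Taken together, the checks above enforce one tool call per planning step and defer finalization whenever a policy violation or validation error was recorded in that step. The shape of the truncation rule, as a standalone sketch (plain dicts standing in for the agent's real call objects):

# Hypothetical proposed calls; only the first is kept and the trimming is logged.
call_list = [{"name": "calculator"}, {"name": "now"}]
state = []
if len(call_list) > 1:
    state.append({
        "tool": "⛔️policy_violation",
        "args": {"reason": "multiple_tool_calls", "count": len(call_list)},
        "observation": "Model returned multiple tool calls in one step; executing only the first.",
    })
    call_list = call_list[:1]
assert [c["name"] for c in call_list] == ["calculator"]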
@@ -399,6 +437,9 @@ def _accumulate_usage(input_text: str = "", output_text: str = ""):
            "tool_calls": tool_calls,
            "provider": self._provider,
            "model": getattr(self.lm, "model", None),
+            "cost": total_cost,
+            "input_tokens": total_in_tokens,
+            "output_tokens": total_out_tokens,
        }
        return p

micro_agent/config.py

Lines changed: 10 additions & 5 deletions
@@ -43,7 +43,8 @@ def __call__(self, *, prompt: str, **kwargs):
                cands = [c.strip() for c in cands if c.strip()]
                expr = max(cands, key=len) if cands else "2+2"
                return _json.dumps({"tool": {"name": "calculator", "args": {"expression": expr}}})
-            if any(w in ql for w in ["time","date","utc","current time","now"]):
+            if ("current time" in ql or "current date" in ql or
+                    re.search(r"\b(time|times|date|dates|utc|now|today|tomorrow|yesterday|timestamp|datetime)\b", ql)):
                return _json.dumps({"tool": {"name": "now", "args": {"timezone": "utc"}}})
            return _json.dumps({"final": {"answer": "ok"}})
        dspy.settings.configure(lm=_MockLM(), track_usage=True)
@@ -75,10 +76,14 @@ def _try(name, fn):

    # Option 2: OpenAI (default)
    openai_model = os.getenv("OPENAI_MODEL", "gpt-4o-mini")
-    if _try("dspy.OpenAI", lambda: dspy.OpenAI(model=openai_model, temperature=temperature, max_tokens=max_tokens)):
-        return
-    if _try("dspy.LM(openai/<model>)", lambda: dspy.LM(f"openai/{openai_model}")):
-        return
+    openai_key = os.getenv("OPENAI_API_KEY")
+    if openai_key:
+        if _try("dspy.OpenAI", lambda: dspy.OpenAI(model=openai_model, temperature=temperature, max_tokens=max_tokens)):
+            return
+        if _try("dspy.LM(openai/<model>)", lambda: dspy.LM(f"openai/{openai_model}")):
+            return
+    else:
+        tried.append(("openai", "missing OPENAI_API_KEY"))

    # If we got here, all backends failed: fall back to mock
    class _FallbackMockLM:
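
With this change the OpenAI path is skipped outright when OPENAI_API_KEY is unset instead of failing inside _try. For offline runs, the mock backend can be selected explicitly, mirroring the regression test added below (environment variable names are taken from this commit):

import os

from micro_agent.config import configure_lm
from micro_agent.agent import MicroAgent

os.environ["LLM_PROVIDER"] = "mock"   # skip the real-backend probing
configure_lm()                        # configures dspy with the built-in mock LM

agent = MicroAgent(max_steps=2)
pred = agent("what is the current time in utc?")  # the mock heuristics should route this to the `now` tool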

micro_agent/runtime.py

Lines changed: 32 additions & 6 deletions
@@ -44,16 +44,40 @@ def dump_trace(trace_id: str, question: str, steps: List[Step], answer: str, *,
        f.write(json.dumps(rec, ensure_ascii=False) + "\n")
    return path

-_JSON_RE = re.compile(r"\{.*\}", re.S)
-
 def extract_json_block(text: str) -> str:
    """
    Extract the first {...} block to survive models adding prose or code fences.
    """
-    m = _JSON_RE.search(text)
-    if not m:
-        raise ValueError(f"No JSON object found in: {text[:200]!r}")
-    return m.group(0)
+    if not text:
+        raise ValueError("No JSON object found in empty text")
+    start = None
+    depth = 0
+    in_str = False
+    escape = False
+    for i, ch in enumerate(text):
+        if start is None:
+            if ch == "{":
+                start = i
+                depth = 1
+            continue
+        if in_str:
+            if escape:
+                escape = False
+            elif ch == "\\":
+                escape = True
+            elif ch == "\"":
+                in_str = False
+            continue
+        if ch == "\"":
+            in_str = True
+            continue
+        if ch == "{":
+            depth += 1
+        elif ch == "}":
+            depth -= 1
+            if depth == 0:
+                return text[start:i + 1]
+    raise ValueError(f"No JSON object found in: {text[:200]!r}")

 def parse_decision_text(text: str) -> Dict[str, Any]:
    """Parse a model decision string into a dict.
@@ -73,6 +97,8 @@ def parse_decision_text(text: str) -> Dict[str, Any]:
    if json_repair is not None:
        try:
            repaired = json_repair.repair(block)
+            if isinstance(repaired, dict):
+                return repaired
            return json.loads(repaired)
        except Exception:
            pass
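
The brace-matching scanner replaces the old greedy regex, so prose before or after the object, even prose containing stray braces, no longer corrupts the extraction. A quick usage example, assuming the module is importable:

from micro_agent.runtime import extract_json_block

text = 'Decision: {"final": {"answer": "4"}} -- note the {stray} braces in the trailing prose'
print(extract_json_block(text))   # {"final": {"answer": "4"}}
# The old pattern r"\{.*\}" with re.S would have matched through to the last "}" and returned invalid JSON.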

micro_agent/signatures.py

Lines changed: 8 additions & 1 deletion
@@ -1,6 +1,13 @@
 from __future__ import annotations
 import dspy
-from dspy.adapters import Tool as DSpyTool, ToolCalls
+try:
+    from dspy.adapters import Tool as DSpyTool, ToolCalls
+except Exception:
+    try:
+        from dspy.adapters.types import Tool as DSpyTool, ToolCalls  # type: ignore
+    except Exception:
+        DSpyTool = object  # type: ignore
+        ToolCalls = object  # type: ignore

 class PlanOrAct(dspy.Signature):
    """Decide next step: either call a tool with JSON args or finalize.

micro_agent/tools.py

Lines changed: 22 additions & 6 deletions
@@ -28,7 +28,7 @@ def spec(self) -> Dict[str, Any]:
 ALLOWED_CALLS = {"fact": lambda x: math.factorial(int(x))}
 def _eval_expr(node):
    # Python 3.10+: numeric literals appear as ast.Constant
-    if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
+    if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)) and not isinstance(node.value, bool):
        return node.value
    if isinstance(node, ast.BinOp) and type(node.op) in ALLOWED_OPS:
        lv, rv = _eval_expr(node.left), _eval_expr(node.right)
@@ -37,18 +37,31 @@ def _eval_expr(node):
        if isinstance(node.op, ast.Pow):
            if isinstance(rv, (int, float)) and abs(rv) > MAX_EXPONENT:
                raise ValueError("exponent too large")
-        return ALLOWED_OPS[type(node.op)](lv, rv)
+        result = ALLOWED_OPS[type(node.op)](lv, rv)
+        if isinstance(result, complex):
+            raise ValueError("complex results are not supported")
+        return result
    if isinstance(node, ast.UnaryOp) and type(node.op) in ALLOWED_OPS:
        v = _eval_expr(node.operand)
        if isinstance(v, (int, float)) and abs(v) > MAX_ABS_NUMBER: raise ValueError("number too large")
-        return ALLOWED_OPS[type(node.op)](v)
+        result = ALLOWED_OPS[type(node.op)](v)
+        if isinstance(result, complex):
+            raise ValueError("complex results are not supported")
+        return result
    if isinstance(node, ast.Call) and isinstance(node.func, ast.Name) and node.func.id in ALLOWED_CALLS:
        if len(node.args) != 1:
            raise ValueError("Invalid arguments")
        arg = _eval_expr(node.args[0])
-        if isinstance(arg, (int, float)) and arg > MAX_FACTORIAL_N:
+        if not isinstance(arg, (int, float)) or isinstance(arg, bool):
+            raise ValueError("factorial requires a number")
+        if isinstance(arg, float) and not arg.is_integer():
+            raise ValueError("factorial requires an integer")
+        arg_int = int(arg)
+        if arg_int < 0:
+            raise ValueError("factorial requires a non-negative integer")
+        if arg_int > MAX_FACTORIAL_N:
            raise ValueError("factorial too large")
-        return ALLOWED_CALLS[node.func.id](arg)
+        return ALLOWED_CALLS[node.func.id](arg_int)
    if isinstance(node, ast.Expression): return _eval_expr(node.body)
    raise ValueError("Disallowed expression")
@@ -68,7 +81,10 @@ def safe_eval_math(expr: str) -> float:
    # cap complexity
    if sum(1 for _ in ast.walk(tree)) > MAX_ALLOWED_OPS_NODES:
        raise ValueError("expression too complex")
-    return _eval_expr(tree)
+    result = _eval_expr(tree)
+    if isinstance(result, complex):
+        raise ValueError("complex results are not supported")
+    return result

 def tool_calculator(args: Dict[str, Any]):
    expr = str(args.get("expression", "")).strip()
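
The hardened evaluator now rejects booleans posing as numbers, complex results, and non-integer or negative factorial arguments. A few illustrative calls, mirroring the new regression tests:

from micro_agent.tools import safe_eval_math

print(safe_eval_math("2+2"))   # 4

for bad in ("fact(3.5)", "(-1) ** 0.5"):   # non-integer factorial; complex result
    try:
        safe_eval_math(bad)
    except ValueError as exc:
        print(f"{bad!r} rejected: {exc}")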

tests/test_regressions.py

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+import pytest
+
+from micro_agent.config import configure_lm
+from micro_agent.agent import MicroAgent
+from micro_agent.tools import safe_eval_math
+
+
+def test_no_false_time_trigger_on_update(monkeypatch):
+    monkeypatch.setenv("LLM_PROVIDER", "mock")
+    configure_lm()
+    agent = MicroAgent(max_steps=2)
+    pred = agent("Please update the docs.")
+    assert not any(step.get("tool") == "now" for step in (pred.trace or []))
+
+
+def test_factorial_rejects_non_integer():
+    with pytest.raises(ValueError):
+        safe_eval_math("fact(3.5)")
+
+
+def test_complex_results_rejected():
+    with pytest.raises(ValueError):
+        safe_eval_math("(-1)^(0.5)")
