Skip to content

Commit 59989c3

Browse files
committed
feat(optimize): stricter metric + labeled=0 + filter demos; agent(OpenAI): compose final from tool results to preserve numeric answers; tests passing
1 parent: e3d91bc; commit: 59989c3

File tree

3 files changed

+35
-74
lines changed

3 files changed

+35
-74
lines changed

micro_agent/agent.py

Lines changed: 12 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -211,7 +211,18 @@ def _accumulate_usage():
211211
if must_time and not used_tool(state, "now"):
212212
state.append({"tool": "⛔️policy_violation", "args": {}, "observation": "Finalize before now (OpenAI path)."})
213213
continue
214-
p = dspy.Prediction(answer=final, trace=state)
214+
# Prefer composing from tool results when available to ensure answers include key values.
215+
composed = []
216+
calculators = [s for s in state if s.get("tool") == "calculator" and isinstance(s.get("observation"), dict)]
217+
nows = [s for s in state if s.get("tool") == "now" and isinstance(s.get("observation"), dict)]
218+
if calculators:
219+
composed.append(str(calculators[0]["observation"].get("result")))
220+
if nows:
221+
iso = nows[-1]["observation"].get("iso")
222+
if iso:
223+
composed.append(f"UTC: {iso}")
224+
answer_text = " | ".join(composed) if composed else final
225+
p = dspy.Prediction(answer=answer_text, trace=state)
215226
p.usage = {
216227
"lm_calls": lm_calls,
217228
"tool_calls": tool_calls,

micro_agent/optimize.py

Lines changed: 19 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -134,16 +134,19 @@ def forward(self, question: str, state: str = "[]", tools=None):
134134
def metric(example, pred, trace):
135135
q = example.get('question', '')
136136
expect = example.get('expect_contains')
137-
score = 0.0
138137
calls = getattr(pred, 'tool_calls', None)
139-
if any(ch.isdigit() for ch in q) and calls:
140-
score += 0.5
141-
if ("time" in q.lower() or "utc" in q.lower()) and calls:
142-
score += 0.5
143138
fin = getattr(pred, 'final', '') or ''
144-
if expect and expect in str(fin):
145-
score += 1.0
146-
return score
139+
140+
# If we know the expected substring (math tasks), require it in final.
141+
if expect:
142+
return 1.0 if (fin and expect in str(fin)) else 0.0
143+
144+
# Otherwise (e.g., time tasks), accept when appropriate tool is used.
145+
if calls and getattr(calls, 'tool_calls', None):
146+
for c in calls.tool_calls:
147+
if getattr(c, 'name', '') == 'now':
148+
return 1.0
149+
return 0.0
147150

148151
# Build trainset Examples
149152
trainset: List[Example] = []
@@ -155,18 +158,23 @@ def metric(example, pred, trace):
155158
ex = ex.with_inputs('question', 'state', 'tools')
156159
trainset.append(ex)
157160

158-
tele = BootstrapFewShot(metric=metric, max_bootstrapped_demos=8, max_labeled_demos=8, max_rounds=1)
161+
tele = BootstrapFewShot(metric=metric, metric_threshold=1.0, max_bootstrapped_demos=8, max_labeled_demos=0, max_rounds=1)
159162
compiled = tele.compile(Planner(), trainset=trainset)
160163

161164
# Extract demos from the compiled predictor
162165
demos = []
163166
for demo in getattr(compiled.decide, 'demos', []) or []:
164167
raw = demo.toDict()
168+
tool_calls = _serialize_tool_calls(raw.get("tool_calls"))
169+
final = raw.get("final")
170+
# Keep only demos that actually contain signals (augmented)
171+
if not tool_calls and not final:
172+
continue
165173
record = {
166174
"question": raw.get("question"),
167175
"state": raw.get("state", "[]"),
168-
"tool_calls": _serialize_tool_calls(raw.get("tool_calls")),
169-
"final": raw.get("final"),
176+
"tool_calls": tool_calls,
177+
"final": final,
170178
}
171179
demos.append(record)
172180

opt/plan_demos.json

Lines changed: 4 additions & 62 deletions
Original file line number | Diff line number | Diff line change
@@ -1,17 +1,4 @@
11
[
2-
{
3-
"question": "What's 2*(3+5)? Return only the number.",
4-
"state": "[]",
5-
"tool_calls": [
6-
{
7-
"name": "calculator",
8-
"args": {
9-
"expression": "2*(3+5)"
10-
}
11-
}
12-
],
13-
"final": null
14-
},
152
{
163
"question": "What time is it right now? Use UTC.",
174
"state": "[]",
@@ -26,28 +13,9 @@
2613
"final": null
2714
},
2815
{
29-
"question": "Compute (7**2 + 14) / 5 and explain briefly.",
30-
"state": "[]",
31-
"tool_calls": [
32-
{
33-
"name": "calculator",
34-
"args": {
35-
"expression": "(7**2 + 14) / 5"
36-
}
37-
}
38-
],
39-
"final": null
40-
},
41-
{
42-
"question": "Add 12345 and 67890, then tell me the current date (UTC).",
16+
"question": "What time is it right now? Use UTC.",
4317
"state": "[]",
4418
"tool_calls": [
45-
{
46-
"name": "calculator",
47-
"args": {
48-
"expression": "12345 + 67890"
49-
}
50-
},
5119
{
5220
"name": "now",
5321
"args": {
@@ -58,39 +26,13 @@
5826
"final": null
5927
},
6028
{
61-
"question": "If I spend $12.50 daily for 9 days, what's the total?",
62-
"state": "[]",
63-
"tool_calls": [
64-
{
65-
"name": "calculator",
66-
"args": {
67-
"expression": "12.50 * 9"
68-
}
69-
}
70-
],
71-
"final": null
72-
},
73-
{
74-
"question": "What's 9! / (3!*3!*3!)? Just the integer.",
75-
"state": "[]",
76-
"tool_calls": [
77-
{
78-
"name": "calculator",
79-
"args": {
80-
"expression": "9! / (3!*3!*3!)"
81-
}
82-
}
83-
],
84-
"final": null
85-
},
86-
{
87-
"question": "What's 2*(3+5)? Return only the number.",
29+
"question": "What time is it right now? Use UTC.",
8830
"state": "[]",
8931
"tool_calls": [
9032
{
91-
"name": "calculator",
33+
"name": "now",
9234
"args": {
93-
"expression": "2*(3+5)"
35+
"timezone": "utc"
9436
}
9537
}
9638
],

0 commit comments

Comments
 (0)