Skip to content

Commit be7661d

Browse files
committed
Isolate usage tracking and unify trace serialization
1 parent 83f359c commit be7661d

File tree

4 files changed

+42
-19
lines changed

4 files changed

+42
-19
lines changed

micro_agent/agent.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,19 @@ class MicroAgent(dspy.Module):
1212
The "agent framework": ~100 LOC.
1313
Plan -> (optional) tool -> observe -> loop -> finalize.
1414
"""
15-
def __init__(self, max_steps: int = 6, use_tool_calls: bool | None = None):
15+
def __init__(self, max_steps: int = 6, use_tool_calls: bool | None = None, use_global_trace: bool | None = None):
1616
super().__init__()
1717
# Use LM directly for robust JSON handling across providers.
1818
self.lm = dspy.settings.lm
1919
self.finalize = None # fallback finalize handled via LM prompt
2020
self._tool_list = [t.spec() for t in TOOLS.values()]
2121
self.max_steps = max_steps
2222
self._provider = self._infer_provider(self.lm)
23+
if isinstance(use_global_trace, bool):
24+
self._use_global_trace = use_global_trace
25+
else:
26+
env_gt = os.getenv("MICRO_AGENT_USE_GLOBAL_TRACE")
27+
self._use_global_trace = env_gt.strip().lower() not in {"0", "false", "no", "off"} if env_gt else True
2328
# Determine function-calls mode
2429
env_override = os.getenv("USE_TOOL_CALLS")
2530
if isinstance(use_tool_calls, bool):
@@ -205,17 +210,18 @@ def _accumulate_usage(input_text: str = "", output_text: str = ""):
205210
in_tok = 0
206211
out_tok = 0
207212
cost = 0.0
208-
try:
209-
for _, _, out in dspy.settings.trace[-1:]:
210-
usage = getattr(out, "usage", None) or {}
211-
nonlocal total_cost, total_in_tokens, total_out_tokens
212-
c = getattr(out, "cost", None)
213-
if c is not None:
214-
cost += float(c or 0)
215-
in_tok += int(usage.get("input_tokens", 0) or 0)
216-
out_tok += int(usage.get("output_tokens", 0) or 0)
217-
except Exception:
218-
pass
213+
if self._use_global_trace:
214+
try:
215+
for _, _, out in dspy.settings.trace[-1:]:
216+
usage = getattr(out, "usage", None) or {}
217+
nonlocal total_cost, total_in_tokens, total_out_tokens
218+
c = getattr(out, "cost", None)
219+
if c is not None:
220+
cost += float(c or 0)
221+
in_tok += int(usage.get("input_tokens", 0) or 0)
222+
out_tok += int(usage.get("output_tokens", 0) or 0)
223+
except Exception:
224+
pass
219225
if in_tok or out_tok or cost:
220226
total_cost += cost
221227
total_in_tokens += in_tok

micro_agent/runtime.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,13 @@
1313
def _get_traces_dir() -> str:
1414
return os.getenv("TRACES_DIR", TRACES_DIR)
1515

16+
def to_jsonable(obj: Any) -> Any:
17+
"""Coerce arbitrary objects to JSON-serializable structures."""
18+
try:
19+
return json.loads(json.dumps(obj, ensure_ascii=False, default=str))
20+
except Exception:
21+
return str(obj)
22+
1623
class Step(TypedDict):
1724
tool: str
1825
args: Dict[str, Any]
@@ -44,9 +51,10 @@ def dump_trace(trace_id: str, question: str, steps: List[Step], answer: str, *,
4451
rec["cost_usd"] = float(cost_usd)
4552
traces_dir = _get_traces_dir()
4653
os.makedirs(traces_dir, exist_ok=True)
54+
rec_jsonable = to_jsonable(rec)
4755
path = os.path.join(traces_dir, f"{trace_id}.jsonl")
4856
with open(path, "a", encoding="utf-8") as f:
49-
f.write(json.dumps(rec, ensure_ascii=False, default=str) + "\n")
57+
f.write(json.dumps(rec_jsonable, ensure_ascii=False) + "\n")
5058
return path
5159

5260
def extract_json_block(text: str) -> str:

micro_agent/server.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,28 @@
11
from __future__ import annotations
22
from fastapi import FastAPI, HTTPException
33
from fastapi.middleware.cors import CORSMiddleware
4-
from fastapi.encoders import jsonable_encoder
54
import os, json, re
65
from threading import Lock
76
from pydantic import BaseModel
87
from .costs import estimate_prediction_cost
98
from importlib.metadata import version as _pkg_version, PackageNotFoundError
109
from .config import configure_lm
1110
from .agent import MicroAgent
12-
from .runtime import dump_trace, new_trace_id
11+
from .runtime import dump_trace, new_trace_id, to_jsonable
1312
from .logging_setup import setup_logging
1413

1514
app = FastAPI(title="DSPy Micro Agent")
15+
origins_env = os.getenv("MICRO_AGENT_CORS_ORIGINS", "*").strip()
16+
if origins_env == "*":
17+
allow_origins = ["*"]
18+
allow_credentials = False
19+
else:
20+
allow_origins = [o.strip() for o in origins_env.split(",") if o.strip()]
21+
allow_credentials = os.getenv("MICRO_AGENT_CORS_CREDENTIALS", "0").strip().lower() in {"1", "true", "yes", "on"}
1622
app.add_middleware(
1723
CORSMiddleware,
18-
allow_origins=["*"],
19-
allow_credentials=True,
24+
allow_origins=allow_origins,
25+
allow_credentials=allow_credentials,
2026
allow_methods=["*"],
2127
allow_headers=["*"],
2228
)
@@ -63,7 +69,9 @@ def ask(req: AskRequest):
6369
raise HTTPException(status_code=400, detail="max_steps must be between 1 and 20")
6470

6571
def _call_agent():
66-
agent = _agent if req.use_tool_calls is None and req.max_steps == _agent.max_steps else MicroAgent(max_steps=req.max_steps, use_tool_calls=req.use_tool_calls)
72+
if _serialize and req.use_tool_calls is None and req.max_steps == _agent.max_steps:
73+
return _agent(question)
74+
agent = MicroAgent(max_steps=req.max_steps, use_tool_calls=req.use_tool_calls, use_global_trace=_serialize)
6775
return agent(question)
6876

6977
if _serialize:
@@ -76,7 +84,7 @@ def _call_agent():
7684
usage = getattr(pred, "usage", {}) or {}
7785
est = estimate_prediction_cost(question, pred.trace, pred.answer, usage)
7886
path = dump_trace(trace_id, question, pred.trace, pred.answer, usage=usage, cost_usd=est.get("cost_usd"))
79-
steps = jsonable_encoder(pred.trace)
87+
steps = to_jsonable(pred.trace)
8088
return AskResponse(answer=pred.answer, trace_id=trace_id, trace_path=path, steps=steps, usage=usage, cost_usd=est.get("cost_usd"))
8189

8290
@app.get("/trace/{trace_id}")

tests/test_regressions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def test_dump_trace_serializes_non_json(tmp_path, monkeypatch):
4646
with open(path, "r", encoding="utf-8") as f:
4747
rec = json.loads(f.readline())
4848
assert rec["steps"][0]["observation"]["when"].startswith("2020-01-01")
49+
assert rec["steps"] == runtime.to_jsonable(steps)
4950

5051

5152
def test_result_magnitude_limit():

0 commit comments

Comments
 (0)