Skip to content

Commit 54650a6

Browse files
committed
Harden server validation and trace handling
1 parent 2c0582f commit 54650a6

File tree

4 files changed

+64
-9
lines changed

4 files changed

+64
-9
lines changed

micro_agent/runtime.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,9 @@
1010
TRACES_DIR = os.getenv("TRACES_DIR", "traces")
1111
os.makedirs(TRACES_DIR, exist_ok=True)
1212

13+
def _get_traces_dir() -> str:
    """Return the directory that trace files should be written to.

    The ``TRACES_DIR`` environment variable is read on every call, so a value
    set after this module was imported is honoured. When the variable is
    unset or set to an empty string, the module-level ``TRACES_DIR`` value is
    returned instead.
    """
    # Use `or` instead of a getenv default: an empty TRACES_DIR would otherwise
    # come back as "" and the later os.makedirs("") raises FileNotFoundError.
    return os.getenv("TRACES_DIR") or TRACES_DIR
15+
1316
class Step(TypedDict):
1417
tool: str
1518
args: Dict[str, Any]
@@ -39,8 +42,9 @@ def dump_trace(trace_id: str, question: str, steps: List[Step], answer: str, *,
3942
rec["usage"] = usage
4043
if cost_usd is not None:
4144
rec["cost_usd"] = float(cost_usd)
42-
os.makedirs(TRACES_DIR, exist_ok=True)
43-
path = os.path.join(TRACES_DIR, f"{trace_id}.jsonl")
45+
traces_dir = _get_traces_dir()
46+
os.makedirs(traces_dir, exist_ok=True)
47+
path = os.path.join(traces_dir, f"{trace_id}.jsonl")
4448
with open(path, "a", encoding="utf-8") as f:
4549
f.write(json.dumps(rec, ensure_ascii=False, default=str) + "\n")
4650
return path

micro_agent/server.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from __future__ import annotations
22
from fastapi import FastAPI, HTTPException
33
from fastapi.middleware.cors import CORSMiddleware
4-
import os, json
4+
from fastapi.encoders import jsonable_encoder
5+
import os, json, re
6+
from threading import Lock
57
from pydantic import BaseModel
68
from .costs import estimate_prediction_cost
79
from importlib.metadata import version as _pkg_version, PackageNotFoundError
@@ -35,6 +37,8 @@ class AskResponse(BaseModel):
3537
setup_logging()
3638
configure_lm()
3739
_agent = MicroAgent()
40+
_agent_lock = Lock()
41+
_serialize = os.getenv("MICRO_AGENT_SERIALIZE", "1").strip().lower() not in {"0", "false", "no", "off"}
3842

3943
@app.get("/healthz")
4044
def healthz():
@@ -52,16 +56,33 @@ def health():
5256

5357
@app.post("/ask", response_model=AskResponse)
5458
def ask(req: AskRequest):
55-
agent = _agent if req.use_tool_calls is None and req.max_steps == _agent.max_steps else MicroAgent(max_steps=req.max_steps, use_tool_calls=req.use_tool_calls)
56-
pred = agent(req.question)
59+
question = (req.question or "").strip()
60+
if not question:
61+
raise HTTPException(status_code=400, detail="question must be a non-empty string")
62+
if req.max_steps < 1 or req.max_steps > 20:
63+
raise HTTPException(status_code=400, detail="max_steps must be between 1 and 20")
64+
65+
def _call_agent():
66+
agent = _agent if req.use_tool_calls is None and req.max_steps == _agent.max_steps else MicroAgent(max_steps=req.max_steps, use_tool_calls=req.use_tool_calls)
67+
return agent(question)
68+
69+
if _serialize:
70+
with _agent_lock:
71+
pred = _call_agent()
72+
else:
73+
pred = _call_agent()
74+
5775
trace_id = new_trace_id()
5876
usage = getattr(pred, "usage", {}) or {}
59-
est = estimate_prediction_cost(req.question, pred.trace, pred.answer, usage)
60-
path = dump_trace(trace_id, req.question, pred.trace, pred.answer, usage=usage, cost_usd=est.get("cost_usd"))
61-
return AskResponse(answer=pred.answer, trace_id=trace_id, trace_path=path, steps=pred.trace, usage=usage, cost_usd=est.get("cost_usd"))
77+
est = estimate_prediction_cost(question, pred.trace, pred.answer, usage)
78+
path = dump_trace(trace_id, question, pred.trace, pred.answer, usage=usage, cost_usd=est.get("cost_usd"))
79+
steps = jsonable_encoder(pred.trace)
80+
return AskResponse(answer=pred.answer, trace_id=trace_id, trace_path=path, steps=steps, usage=usage, cost_usd=est.get("cost_usd"))
6281

6382
@app.get("/trace/{trace_id}")
6483
def get_trace(trace_id: str):
84+
if not re.fullmatch(r"[0-9a-f]{32}", trace_id, flags=re.IGNORECASE):
85+
raise HTTPException(status_code=400, detail="Invalid trace id format")
6586
traces_dir = os.getenv("TRACES_DIR", "traces")
6687
path = os.path.join(traces_dir, f"{trace_id}.jsonl")
6788
if not os.path.exists(path):

tests/test_regressions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def test_times_is_math_not_time(monkeypatch):
3939

4040

4141
def test_dump_trace_serializes_non_json(tmp_path, monkeypatch):
42-
monkeypatch.setattr(runtime, "TRACES_DIR", str(tmp_path))
42+
monkeypatch.setenv("TRACES_DIR", str(tmp_path))
4343
trace_id = runtime.new_trace_id()
4444
steps = [{"tool": "now", "args": {}, "observation": {"when": datetime.datetime(2020, 1, 1)}}]
4545
path = runtime.dump_trace(trace_id, "q", steps, "a")

tests/test_server.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import importlib
2+
3+
from fastapi.testclient import TestClient
4+
5+
6+
def _client(monkeypatch):
    """Build a TestClient around a freshly reloaded server module.

    ``LLM_PROVIDER`` is set to ``"mock"`` before the reload so that the
    module-level initialisation in ``micro_agent.server`` is re-run with the
    variable in place.
    """
    monkeypatch.setenv("LLM_PROVIDER", "mock")
    import micro_agent.server as server_module

    # reload() returns the same module object it was given, now re-executed.
    reloaded = importlib.reload(server_module)
    return TestClient(reloaded.app)
11+
12+
13+
def test_ask_rejects_empty_question(monkeypatch):
    """A whitespace-only question posted to /ask is rejected with HTTP 400."""
    payload = {"question": " ", "max_steps": 2}
    reply = _client(monkeypatch).post("/ask", json=payload)
    assert reply.status_code == 400
17+
18+
19+
def test_ask_max_steps_bounds(monkeypatch):
    """max_steps values of 0 and 100 are each rejected by /ask with HTTP 400."""
    api = _client(monkeypatch)
    # Same order as before: the too-small value first, then the too-large one.
    for steps in (0, 100):
        reply = api.post("/ask", json={"question": "hi", "max_steps": steps})
        assert reply.status_code == 400
25+
26+
27+
def test_trace_id_validation(monkeypatch):
    """A trace id that is not 32 hex characters yields HTTP 400 from /trace."""
    status = _client(monkeypatch).get("/trace/not-a-uuid").status_code
    assert status == 400

0 commit comments

Comments
 (0)