11from __future__ import annotations
22from fastapi import FastAPI , HTTPException
33from fastapi .middleware .cors import CORSMiddleware
4- import os , json
4+ from fastapi .encoders import jsonable_encoder
5+ import os , json , re
6+ from threading import Lock
57from pydantic import BaseModel
68from .costs import estimate_prediction_cost
79from importlib .metadata import version as _pkg_version , PackageNotFoundError
@@ -35,6 +37,8 @@ class AskResponse(BaseModel):
3537setup_logging ()
3638configure_lm ()
3739_agent = MicroAgent ()
40+ _agent_lock = Lock ()
41+ _serialize = os .getenv ("MICRO_AGENT_SERIALIZE" , "1" ).strip ().lower () not in {"0" , "false" , "no" , "off" }
3842
3943@app .get ("/healthz" )
4044def healthz ():
@@ -52,16 +56,33 @@ def health():
5256
5357@app .post ("/ask" , response_model = AskResponse )
5458def ask (req : AskRequest ):
55- agent = _agent if req .use_tool_calls is None and req .max_steps == _agent .max_steps else MicroAgent (max_steps = req .max_steps , use_tool_calls = req .use_tool_calls )
56- pred = agent (req .question )
59+ question = (req .question or "" ).strip ()
60+ if not question :
61+ raise HTTPException (status_code = 400 , detail = "question must be a non-empty string" )
62+ if req .max_steps < 1 or req .max_steps > 20 :
63+ raise HTTPException (status_code = 400 , detail = "max_steps must be between 1 and 20" )
64+
65+ def _call_agent ():
66+ agent = _agent if req .use_tool_calls is None and req .max_steps == _agent .max_steps else MicroAgent (max_steps = req .max_steps , use_tool_calls = req .use_tool_calls )
67+ return agent (question )
68+
69+ if _serialize :
70+ with _agent_lock :
71+ pred = _call_agent ()
72+ else :
73+ pred = _call_agent ()
74+
5775 trace_id = new_trace_id ()
5876 usage = getattr (pred , "usage" , {}) or {}
59- est = estimate_prediction_cost (req .question , pred .trace , pred .answer , usage )
60- path = dump_trace (trace_id , req .question , pred .trace , pred .answer , usage = usage , cost_usd = est .get ("cost_usd" ))
61- return AskResponse (answer = pred .answer , trace_id = trace_id , trace_path = path , steps = pred .trace , usage = usage , cost_usd = est .get ("cost_usd" ))
77+ est = estimate_prediction_cost (question , pred .trace , pred .answer , usage )
78+ path = dump_trace (trace_id , question , pred .trace , pred .answer , usage = usage , cost_usd = est .get ("cost_usd" ))
79+ steps = jsonable_encoder (pred .trace )
80+ return AskResponse (answer = pred .answer , trace_id = trace_id , trace_path = path , steps = steps , usage = usage , cost_usd = est .get ("cost_usd" ))
6281
6382@app .get ("/trace/{trace_id}" )
6483def get_trace (trace_id : str ):
84+ if not re .fullmatch (r"[0-9a-f]{32}" , trace_id , flags = re .IGNORECASE ):
85+ raise HTTPException (status_code = 400 , detail = "Invalid trace id format" )
6586 traces_dir = os .getenv ("TRACES_DIR" , "traces" )
6687 path = os .path .join (traces_dir , f"{ trace_id } .jsonl" )
6788 if not os .path .exists (path ):
0 commit comments