Skip to content

Commit ad259a6

Browse files
authored
enhancement: add supervisor agent for routing query (#42)
* adding supervisor agent * add supervisor route * remove edit button * remove edit button * first send to supervisor endpoint * bind 8005 to supervisor * refactor logic into classes * require json output for supervisor agent * break apart supervisor logic * superviosr logic simplification * add polling to help support stream * move logging process out of main * remove old logic * update fallback logic to extract yaml * update greeting msg to a const * show immediate agent output * add dynamic dots * remove static dots
1 parent 4dcfa24 commit ad259a6

File tree

13 files changed

+1564
-296
lines changed

13 files changed

+1564
-296
lines changed

api/main.py

Lines changed: 184 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from datetime import datetime
1212
from api.ai_agent import MaestroBuilderAgent
1313
from api.database import Database
14+
from api.supervisor import SupervisorAgent, Intent
1415
import uuid
1516
import subprocess
1617
import tempfile
@@ -21,6 +22,7 @@
2122
import asyncio
2223
from pathlib import Path
2324
import httpx
25+
from concurrent.futures import ThreadPoolExecutor
2426

2527
# Initialize FastAPI app
2628
app = FastAPI(
@@ -104,6 +106,27 @@ class ValidateYamlResponse(BaseModel):
104106
message: str
105107
errors: List[str] = []
106108

109+
# Status tracking for frontend updates
110+
status_updates = {}
111+
last_sent_index = {}
112+
request_results = {}
113+
114+
def create_status_logger(chat_id: str):
115+
"""Create a logger function that tracks status updates for the frontend."""
116+
def log_status(message: str, level: str = "info"):
117+
if chat_id not in status_updates:
118+
status_updates[chat_id] = []
119+
update = {
120+
"message": message,
121+
"level": level,
122+
"timestamp": datetime.now().isoformat()
123+
}
124+
status_updates[chat_id].append(update)
125+
return log_status
126+
127+
supervisor_agent = SupervisorAgent()
128+
executor = ThreadPoolExecutor(max_workers=4)
129+
107130
# ---------------------------------------
108131
# Service Functions
109132
# ---------------------------------------
@@ -302,7 +325,7 @@ def to_line(obj: Dict[str, Any]) -> bytes:
302325
await asyncio.sleep(0)
303326
yield to_line({"type": "status", "message": "(Planning agents)"})
304327
await asyncio.sleep(0)
305-
yield to_line({"type": "status", "message": "Generating agents.yaml..."})
328+
yield to_line({"type": "status", "message": "Generating agents.yaml"})
306329
await asyncio.sleep(0)
307330

308331
agents_output, agents_yaml = await generate_agents_yaml(message.content)
@@ -346,7 +369,7 @@ def to_line(obj: Dict[str, Any]) -> bytes:
346369

347370
yield to_line({"type": "status", "message": "(Building workflow prompt)"})
348371
await asyncio.sleep(0)
349-
yield to_line({"type": "status", "message": "Generating workflow.yaml..."})
372+
yield to_line({"type": "status", "message": "Generating workflow.yaml"})
350373
await asyncio.sleep(0)
351374

352375
async with httpx.AsyncClient(timeout=180) as client:
@@ -553,23 +576,21 @@ async def delete_chat_session(chat_id: str):
553576

554577
@app.post("/api/edit_yaml", response_model=EditYamlResponse)
555578
async def edit_yaml(request: EditYamlRequest):
579+
"""Edit YAML content based on user instruction."""
580+
if not request.yaml or not request.yaml.strip():
581+
raise HTTPException(status_code=400, detail="YAML content cannot be empty")
582+
if not request.instruction or not request.instruction.strip():
583+
raise HTTPException(status_code=400, detail="Instruction cannot be empty")
584+
if not request.file_type or request.file_type not in ["agents", "workflow"]:
585+
raise HTTPException(status_code=400, detail="File type must be 'agents' or 'workflow'")
586+
556587
try:
557-
# Build the prompt for the editing agent
558-
prompt = f"Current YAML file (type: {request.file_type}):\n{request.yaml}\n\nUser instruction: {request.instruction}\n\nPlease apply the requested edit and return only the updated YAML file."
559-
resp = requests.post(
560-
"http://localhost:8002/chat",
561-
json={"prompt": prompt},
588+
file_name = f"{request.file_type}.yaml"
589+
edited_yaml = supervisor_agent.edit_yaml(
590+
yaml_content=request.yaml,
591+
file_to_edit=file_name,
592+
instruction=request.instruction
562593
)
563-
if resp.status_code != 200:
564-
raise Exception(resp.text)
565-
edited_yaml = resp.json().get("response", "")
566-
# Remove markdown formatting if present
567-
if "```yaml" in edited_yaml:
568-
edited_yaml = (
569-
edited_yaml.split("```yaml", 1)[-1].split("```", 1)[0].strip()
570-
)
571-
elif "```" in edited_yaml:
572-
edited_yaml = edited_yaml.split("```", 1)[-1].split("```", 1)[0].strip()
573594
return {"edited_yaml": edited_yaml}
574595
except Exception as e:
575596
raise HTTPException(status_code=500, detail=f"Editing Agent failed: {e}")
@@ -578,7 +599,6 @@ async def edit_yaml(request: EditYamlRequest):
578599
@app.post("/api/validate_yaml", response_model=ValidateYamlResponse)
579600
async def validate_yaml(request: ValidateYamlRequest):
580601
try:
581-
# fix double-escaped characters
582602
import codecs
583603
unescaped_content = codecs.decode(request.yaml_content, 'unicode_escape')
584604
with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as temp_file:
@@ -648,6 +668,152 @@ def strip_ansi_codes(text):
648668
)
649669

650670

671+
class SupervisorRequest(BaseModel):
672+
content: str
673+
chat_id: Optional[str] = None
674+
675+
class SupervisorResponse(BaseModel):
676+
intent: str
677+
confidence: float
678+
reasoning: str
679+
response: str
680+
yaml_files: List[Dict[str, str]]
681+
chat_id: str
682+
683+
class AsyncSupervisorResponse(BaseModel):
684+
request_id: str
685+
status: str
686+
message: str
687+
chat_id: str
688+
689+
def store_request_result(request_id: str, result):
690+
"""Store the result of a background request."""
691+
if isinstance(result, dict) and "error" not in result:
692+
supervisor_result = SupervisorResponse(
693+
intent=result["intent"],
694+
confidence=result["confidence"],
695+
reasoning=result["reasoning"],
696+
response=result["response"],
697+
yaml_files=result["yaml_files"],
698+
chat_id=result["chat_id"],
699+
)
700+
request_results[request_id] = supervisor_result
701+
else:
702+
request_results[request_id] = result
703+
704+
@app.post("/api/supervisor", response_model=SupervisorResponse)
705+
async def supervisor_route(request: SupervisorRequest):
706+
"""
707+
Synchronous supervisor endpoint that processes requests directly.
708+
For asynchronous processing with real-time updates, use /api/supervisor-async.
709+
"""
710+
if not request.content or not request.content.strip():
711+
raise HTTPException(status_code=400, detail="Request content cannot be empty")
712+
713+
chat_id = request.chat_id or str(uuid.uuid4())
714+
715+
try:
716+
result_container = {}
717+
718+
def sync_result_callback(request_id: str, result):
719+
result_container[request_id] = result
720+
721+
temp_request_id = "sync_request"
722+
supervisor_agent.process_request_in_background(
723+
temp_request_id,
724+
request.content,
725+
chat_id,
726+
create_status_logger,
727+
sync_result_callback,
728+
db
729+
)
730+
731+
result = result_container.get(temp_request_id)
732+
733+
if result and isinstance(result, dict) and "error" in result:
734+
raise HTTPException(status_code=500, detail=result["message"])
735+
736+
if isinstance(result, dict):
737+
return SupervisorResponse(
738+
intent=result["intent"],
739+
confidence=result["confidence"],
740+
reasoning=result["reasoning"],
741+
response=result["response"],
742+
yaml_files=result["yaml_files"],
743+
chat_id=result["chat_id"],
744+
)
745+
746+
return result
747+
748+
except HTTPException:
749+
raise
750+
except Exception as e:
751+
raise HTTPException(status_code=500, detail=f"Supervisor processing failed: {str(e)}")
752+
753+
@app.post("/api/supervisor-async", response_model=AsyncSupervisorResponse)
754+
async def supervisor_route_async(request: SupervisorRequest):
755+
"""
756+
Async version of supervisor endpoint that starts background processing
757+
and returns immediately with a request ID for polling.
758+
"""
759+
request_id = str(uuid.uuid4())
760+
chat_id = request.chat_id or str(uuid.uuid4())
761+
# Start background processing using the supervisor agent
762+
executor.submit(
763+
supervisor_agent.process_request_in_background,
764+
request_id,
765+
request.content,
766+
chat_id,
767+
create_status_logger,
768+
store_request_result,
769+
db
770+
)
771+
772+
return AsyncSupervisorResponse(
773+
request_id=request_id,
774+
status="processing",
775+
message="Request started, use the request_id to poll for status and results",
776+
chat_id=chat_id
777+
)
778+
779+
780+
@app.get("/api/supervisor-result/{request_id}")
781+
async def get_supervisor_result(request_id: str):
782+
"""Get the result of an async supervisor request."""
783+
if request_id in request_results:
784+
result = request_results[request_id]
785+
del request_results[request_id]
786+
return result
787+
else:
788+
return {"status": "processing", "message": "Request still in progress"}
789+
790+
791+
@app.get("/api/status/{chat_id}")
792+
async def get_status_updates(chat_id: str):
793+
"""Get status updates for a specific chat ID."""
794+
if chat_id in status_updates:
795+
all_updates = status_updates[chat_id]
796+
last_index = last_sent_index.get(chat_id, 0)
797+
new_updates = all_updates[last_index:]
798+
799+
if new_updates:
800+
last_sent_index[chat_id] = len(all_updates)
801+
return {"updates": new_updates}
802+
else:
803+
return {"updates": []}
804+
return {"updates": []}
805+
806+
807+
@app.delete("/api/status/{chat_id}")
808+
async def clear_status_updates(chat_id: str):
809+
"""Clear status updates for a specific chat ID."""
810+
if chat_id in status_updates:
811+
del status_updates[chat_id]
812+
if chat_id in last_sent_index:
813+
del last_sent_index[chat_id]
814+
return {"message": "Status updates cleared"}
815+
816+
651817
@app.get("/api/health")
652818
async def health_check():
653819
try:

0 commit comments

Comments
 (0)