diff --git a/cookbook/05_agent_os/client/10_sse_reconnect.py b/cookbook/05_agent_os/client/10_sse_reconnect.py new file mode 100644 index 0000000000..a68f91d237 --- /dev/null +++ b/cookbook/05_agent_os/client/10_sse_reconnect.py @@ -0,0 +1,229 @@ +""" +SSE Reconnection +===================== + +Tests SSE stream reconnection for agent runs using background=True, stream=True. +When background=True, the agent runs in a detached task that survives client +disconnections. Events are buffered so the client can reconnect via /resume. + +Steps: +1. Start a streaming run with background=true +2. Disconnect after a few events +3. Reconnect via /resume and catch up on missed events + +Prerequisites: +1. Start the AgentOS server: python cookbook/05_agent_os/basic.py +2. Run this script: python cookbook/05_agent_os/client/10_sse_reconnect.py +""" + +import asyncio +import json +from typing import Optional + +import httpx + +# --------------------------------------------------------------------------- +# Config +# --------------------------------------------------------------------------- +BASE_URL = "http://localhost:7777" +# Number of events to receive before simulating a disconnect +EVENTS_BEFORE_DISCONNECT = 6 +# How long to "stay disconnected" (seconds) +DISCONNECT_DURATION = 3 + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- +def parse_sse_line(line: str) -> Optional[dict]: + """Parse a single SSE data line into a dict.""" + if line.startswith("data: "): + try: + return json.loads(line[6:]) + except json.JSONDecodeError: + return None + return None + + +# --------------------------------------------------------------------------- +# Test +# --------------------------------------------------------------------------- +async def test_sse_reconnection(): + print("=" * 70) + print("Agent SSE Reconnection Test") + print("=" * 70) + + # Step 1: Discover an agent + async with httpx.AsyncClient(base_url=BASE_URL, timeout=30) as client: + resp = await client.get("/agents") + resp.raise_for_status() + agents = resp.json() + if not agents: + print("[ERROR] No agents available on the server") + return + agent_id = agents[0]["id"] + print(f"Using agent: {agent_id} ({agents[0].get('name', 'unnamed')})") + + # Step 2: Start a streaming run and disconnect after a few events + run_id: Optional[str] = None + session_id: Optional[str] = None + last_event_index: Optional[int] = None + events_phase1: list[dict] = [] + + print( + f"\nPhase 1: Starting SSE stream, will disconnect after {EVENTS_BEFORE_DISCONNECT} events..." + ) + + async with httpx.AsyncClient(base_url=BASE_URL, timeout=60) as client: + form_data = { + "message": "Tell me a detailed story about a brave knight who goes on a quest. Make it at least 5 paragraphs long.", + "stream": "true", + "background": "true", + } + async with client.stream( + "POST", f"/agents/{agent_id}/runs", data=form_data + ) as response: + event_count = 0 + buffer = "" + async for chunk in response.aiter_text(): + buffer += chunk + # SSE events are delimited by double newlines + while "\n\n" in buffer: + event_str, buffer = buffer.split("\n\n", 1) + for line in event_str.strip().split("\n"): + data = parse_sse_line(line) + if data is None: + continue + + event_type = data.get("event", "unknown") + ev_idx = data.get("event_index") + ev_run_id = data.get("run_id") + ev_session_id = data.get("session_id") + + # Track run_id and session_id + if ev_run_id and not run_id: + run_id = ev_run_id + if ev_session_id and not session_id: + session_id = ev_session_id + if ev_idx is not None: + last_event_index = ev_idx + + events_phase1.append(data) + event_count += 1 + content_preview = str(data.get("content", ""))[:60] + print( + f" [{event_count}] event={event_type} index={ev_idx} content={content_preview!r}" + ) + + if event_count >= EVENTS_BEFORE_DISCONNECT: + break + if event_count >= EVENTS_BEFORE_DISCONNECT: + break + if event_count >= EVENTS_BEFORE_DISCONNECT: + break + + print( + f"\n[DISCONNECT] Received {event_count} events. run_id={run_id}, last_event_index={last_event_index}" + ) + + if not run_id: + print("[ERROR] Could not determine run_id from events") + return + + # Step 3: Wait (simulate user being away) + print(f"\nSimulating disconnect for {DISCONNECT_DURATION} seconds...") + await asyncio.sleep(DISCONNECT_DURATION) + + # Step 4: Resume via /resume endpoint + print("\nPhase 2: Reconnecting via /resume endpoint...") + events_phase2: list[dict] = [] + + form_data: dict = {} + if last_event_index is not None: + form_data["last_event_index"] = str(last_event_index) + if session_id: + form_data["session_id"] = session_id + + async with httpx.AsyncClient(base_url=BASE_URL, timeout=120) as client: + async with client.stream( + "POST", f"/agents/{agent_id}/runs/{run_id}/resume", data=form_data + ) as response: + buffer = "" + async for chunk in response.aiter_text(): + buffer += chunk + while "\n\n" in buffer: + event_str, buffer = buffer.split("\n\n", 1) + for line in event_str.strip().split("\n"): + data = parse_sse_line(line) + if data is None: + continue + + event_type = data.get("event", "unknown") + ev_idx = data.get("event_index") + events_phase2.append(data) + + if event_type in ("catch_up", "replay", "subscribed"): + print( + f" [META] event={event_type} | {json.dumps(data, indent=2)}" + ) + else: + content_preview = str(data.get("content", ""))[:60] + print( + f" [RESUME] event={event_type} index={ev_idx} content={content_preview!r}" + ) + + # Step 5: Print summary + print("\n" + "=" * 70) + print("Summary") + print("=" * 70) + print(f"Phase 1 events received: {len(events_phase1)}") + print(f"Phase 2 events received: {len(events_phase2)}") + + # Check for meta events + meta_events = [ + e + for e in events_phase2 + if e.get("event") in ("catch_up", "replay", "subscribed") + ] + data_events = [ + e + for e in events_phase2 + if e.get("event") not in ("catch_up", "replay", "subscribed", "error") + ] + print(f" Meta events (catch_up/replay/subscribed): {len(meta_events)}") + print(f" Data events (actual agent events): {len(data_events)}") + + # Validate event_index continuity + phase1_indices = [ + e.get("event_index") for e in events_phase1 if e.get("event_index") is not None + ] + phase2_indices = [ + e.get("event_index") for e in data_events if e.get("event_index") is not None + ] + + if phase1_indices and phase2_indices: + last_p1 = max(phase1_indices) + first_p2 = min(phase2_indices) + last_p2 = max(phase2_indices) + print(f"\n Phase 1 event_index range: 0 -> {last_p1}") + print(f" Phase 2 event_index range: {first_p2} -> {last_p2}") + if first_p2 == last_p1 + 1: + print(" [PASS] Event indices are contiguous - no events were lost") + elif first_p2 > last_p1: + print(f" [WARN] Gap in event indices: {last_p1} -> {first_p2}") + else: + print(" [INFO] Overlapping indices detected (dedup may have occurred)") + elif not phase2_indices: + print( + "\n [INFO] No data events in phase 2 (run may have completed before resume)" + ) + else: + print("\n [INFO] No event indices in phase 1 to compare") + + total_events = len(events_phase1) + len(data_events) + print(f"\n Total unique events across both phases: {total_events}") + print("=" * 70) + + +if __name__ == "__main__": + asyncio.run(test_sse_reconnection()) diff --git a/libs/agno/agno/agent/_run.py b/libs/agno/agno/agent/_run.py index 2e5b5af833..71ddef9e9b 100644 --- a/libs/agno/agno/agent/_run.py +++ b/libs/agno/agno/agent/_run.py @@ -1884,6 +1884,144 @@ async def _background_task() -> None: return run_response +async def _arun_background_stream( + agent: Agent, + run_response: RunOutput, + run_context: RunContext, + session_id: str, + user_id: Optional[str] = None, + add_history_to_context: Optional[bool] = None, + add_dependencies_to_context: Optional[bool] = None, + add_session_state_to_context: Optional[bool] = None, + response_format: Optional[Union[Dict, Type[BaseModel]]] = None, + stream_events: bool = False, + yield_run_output: Optional[bool] = None, + debug_mode: Optional[bool] = None, + background_tasks: Optional[Any] = None, + **kwargs: Any, +) -> AsyncIterator[str]: + """Background streaming agent run that survives client disconnections. + + 1. Persists RUNNING status in DB + 2. Spawns a detached asyncio.Task that runs _arun_stream + 3. Buffers events (via event_buffer) and publishes to SSE subscribers + 4. Yields SSE-formatted strings via an asyncio.Queue + + The detached task keeps running even if the client disconnects. + The caller (router) just yields the SSE strings to the client. + + Similar to how Workflow._arun_background_stream handles WebSocket streaming, + but uses SSE transport with event_buffer and sse_subscriber_manager. + """ + from agno.agent._session import asave_session + from agno.agent._storage import aread_or_create_session, update_metadata + + run_id = run_response.run_id + if not run_id: + raise ValueError("run_id is required for background streaming") + + # 1. Persist RUNNING status so the run is visible in the DB immediately + run_response.status = RunStatus.running + + agent_session = await aread_or_create_session(agent, session_id=session_id, user_id=user_id) + update_metadata(agent, session=agent_session) + agent_session.upsert_run(run=run_response) + await asave_session(agent, session=agent_session) + + log_info(f"Background stream run {run_id} persisted with RUNNING status") + + # 2. Create queue for forwarding SSE strings to the caller + sse_queue: asyncio.Queue[Optional[str]] = asyncio.Queue() + + # 3. Spawn detached background task + async def _background_producer() -> None: + try: + async for event in _arun_stream( + agent, + run_response=run_response, + run_context=run_context, + user_id=user_id, + response_format=response_format, + stream_events=stream_events, + yield_run_output=yield_run_output, + session_id=session_id, + add_history_to_context=add_history_to_context, + add_dependencies_to_context=add_dependencies_to_context, + add_session_state_to_context=add_session_state_to_context, + debug_mode=debug_mode, + background_tasks=background_tasks, + pre_session=agent_session, + **kwargs, + ): + if isinstance(event, RunOutput): + continue + + # Buffer event for reconnection support + event_index: Optional[int] = None + try: + from agno.os.managers import event_buffer + + event_index = event_buffer.add_event(run_id, event) + except Exception: + pass + + # Format as SSE + from agno.os.utils import format_sse_event_with_index + + sse_data = format_sse_event_with_index(event, event_index=event_index, run_id=run_id) + + # Push to primary queue (original client) + try: + await sse_queue.put(sse_data) + except Exception: + pass + + # Publish to SSE subscribers (resumed clients) + try: + from agno.os.managers import sse_subscriber_manager + + await sse_subscriber_manager.publish(run_id, sse_data) + except Exception: + pass + + except Exception: + log_error(f"Background stream run {run_id} failed", exc_info=True) + + finally: + # Mark run completed in event buffer (status is set by _arun_stream/acleanup_and_store) + try: + from agno.os.managers import event_buffer + + event_buffer.set_run_completed(run_id, run_response.status or RunStatus.completed) + except Exception: + pass + + # Signal SSE subscribers that run is done + try: + from agno.os.managers import sse_subscriber_manager + + await sse_subscriber_manager.complete(run_id) + except Exception: + pass + + # Signal primary queue that run is done + try: + await sse_queue.put(None) + except Exception: + pass + + task = asyncio.create_task(_background_producer()) + _background_tasks.add(task) + task.add_done_callback(_background_tasks.discard) + + # 4. Yield SSE strings from the queue + while True: + sse_data = await sse_queue.get() + if sse_data is None: + break + yield sse_data + + async def _arun_stream( agent: Agent, run_response: RunOutput, @@ -2572,16 +2710,30 @@ def arun_dispatch( # type: ignore run_response.metrics = RunMetrics() run_response.metrics.start_timer() - # Background execution: return immediately with PENDING status + # Background execution if background: - if opts.stream: - raise ValueError( - "Background execution cannot be combined with streaming. Set stream=False when using background=True." - ) if not agent.db: raise ValueError( "Background execution requires a database to be configured on the agent for run persistence." ) + if opts.stream: + # background=True, stream=True: run in background task, stream events via queue + return _arun_background_stream( # type: ignore[return-value] + agent, + run_response=run_response, + run_context=run_context, + user_id=user_id, + response_format=response_format, + stream_events=opts.stream_events, + yield_run_output=opts.yield_run_output, + session_id=session_id, + add_history_to_context=opts.add_history_to_context, + add_dependencies_to_context=opts.add_dependencies_to_context, + add_session_state_to_context=opts.add_session_state_to_context, + debug_mode=debug_mode, + background_tasks=background_tasks, + **kwargs, + ) return _arun_background( # type: ignore[return-value] agent, run_response=run_response, diff --git a/libs/agno/agno/os/managers.py b/libs/agno/agno/os/managers.py index 427c30fff8..81466a318c 100644 --- a/libs/agno/agno/os/managers.py +++ b/libs/agno/agno/os/managers.py @@ -5,10 +5,12 @@ - WebSocketManager: WebSocket connection management for real-time streaming - EventsBuffer: Event buffering for agent/team/workflow reconnection support - WebSocketHandler: Handler for sending events over WebSocket connections +- SSESubscriberManager: Subscriber management for SSE-based reconnection These managers are used by agents, teams, and workflows for background WebSocket execution. """ +import asyncio import json from dataclasses import dataclass from time import time @@ -315,6 +317,58 @@ def get_run_status(self, run_id: str) -> Optional[RunStatus]: return metadata["status"] if metadata else None +class SSESubscriberManager: + """ + Manages asyncio.Queue subscribers for SSE-based reconnection. + + When a client reconnects to a still-running agent/team via the /resume SSE endpoint, + it registers a Queue here. The response streamer pushes SSE-formatted events to all + registered queues. A None sentinel signals run completion. + """ + + def __init__(self) -> None: + self._subscribers: Dict[str, List[asyncio.Queue[Optional[str]]]] = {} + + def subscribe(self, run_id: str) -> "asyncio.Queue[Optional[str]]": + """Register a new subscriber queue for a run. Returns the queue.""" + if run_id not in self._subscribers: + self._subscribers[run_id] = [] + queue: asyncio.Queue[Optional[str]] = asyncio.Queue() + self._subscribers[run_id].append(queue) + log_debug(f"SSE subscriber registered for run {run_id}") + return queue + + def unsubscribe(self, run_id: str, queue: "asyncio.Queue[Optional[str]]") -> None: + """Remove a subscriber queue.""" + if run_id in self._subscribers: + try: + self._subscribers[run_id].remove(queue) + except ValueError: + pass + if not self._subscribers[run_id]: + del self._subscribers[run_id] + + async def publish(self, run_id: str, sse_data: str) -> None: + """Push an SSE-formatted event string to all subscriber queues for a run.""" + if run_id not in self._subscribers: + return + for queue in self._subscribers[run_id]: + try: + await queue.put(sse_data) + except Exception: + pass + + async def complete(self, run_id: str) -> None: + """Signal all subscribers that the run is done by pushing None sentinel.""" + if run_id not in self._subscribers: + return + for queue in self._subscribers[run_id]: + try: + await queue.put(None) + except Exception: + pass + + # Global manager instances websocket_manager = WebSocketManager( active_connections={}, @@ -324,3 +378,5 @@ def get_run_status(self, run_id: str) -> Optional[RunStatus]: max_events_per_run=10000, # Keep last 10000 events per run cleanup_interval=1800, # Clean up completed runs after 30 minutes ) + +sse_subscriber_manager = SSESubscriberManager() diff --git a/libs/agno/agno/os/routers/agents/router.py b/libs/agno/agno/os/routers/agents/router.py index ad1b337ca1..b21f0b0059 100644 --- a/libs/agno/agno/os/routers/agents/router.py +++ b/libs/agno/agno/os/routers/agents/router.py @@ -27,6 +27,7 @@ require_approval_resolved, require_resource_access, ) +from agno.os.managers import event_buffer, sse_subscriber_manager from agno.os.routers.agents.schema import AgentResponse from agno.os.schema import ( BadRequestResponse, @@ -49,6 +50,7 @@ from agno.run.agent import RunErrorEvent, RunOutput from agno.run.base import RunStatus from agno.utils.log import log_debug, log_error, log_warning +from agno.utils.serialize import json_serializer if TYPE_CHECKING: from agno.os.app import AgentOS @@ -67,8 +69,8 @@ async def agent_response_streamer( auth_token: Optional[str] = None, **kwargs: Any, ) -> AsyncGenerator: + """Default SSE generator. Agent runs inline — if client disconnects, agent is cancelled.""" try: - # Pass background_tasks if provided if background_tasks is not None: kwargs["background_tasks"] = background_tasks @@ -77,7 +79,6 @@ async def agent_response_streamer( else: stream_events = True - # Pass auth_token for remote agents if auth_token and isinstance(agent, RemoteAgent): kwargs["auth_token"] = auth_token @@ -113,6 +114,55 @@ async def agent_response_streamer( yield format_sse_event(error_response) +async def agent_resumable_response_streamer( + agent: Union[Agent, RemoteAgent], + message: str, + session_id: Optional[str] = None, + user_id: Optional[str] = None, + images: Optional[List[Image]] = None, + audio: Optional[List[Audio]] = None, + videos: Optional[List[Video]] = None, + files: Optional[List[FileMedia]] = None, + background_tasks: Optional[BackgroundTasks] = None, + auth_token: Optional[str] = None, + **kwargs: Any, +) -> AsyncGenerator: + """Resumable SSE generator for background=True, stream=True. + + Delegates to agent.arun(background=True, stream=True) which handles: + - Persisting RUNNING status in DB + - Running agent in a detached asyncio.Task (survives client disconnect) + - Buffering events for reconnection via /resume + - Publishing to SSE subscribers for resumed clients + - Yielding SSE-formatted strings via a queue + """ + if background_tasks is not None: + kwargs["background_tasks"] = background_tasks + + if "stream_events" in kwargs: + stream_events = kwargs.pop("stream_events") + else: + stream_events = True + + if auth_token and isinstance(agent, RemoteAgent): + kwargs["auth_token"] = auth_token + + async for sse_data in agent.arun( + input=message, + session_id=session_id, + user_id=user_id, + images=images, + audio=audio, + videos=videos, + files=files, + stream=True, + stream_events=stream_events, + background=True, + **kwargs, + ): + yield sse_data + + async def agent_continue_response_streamer( agent: Union[Agent, RemoteAgent], run_id: str, @@ -122,8 +172,8 @@ async def agent_continue_response_streamer( background_tasks: Optional[BackgroundTasks] = None, auth_token: Optional[str] = None, ) -> AsyncGenerator: + """Default SSE generator for continue_run. Agent runs inline — client disconnect cancels agent.""" try: - # Build kwargs for remote agent auth extra_kwargs: dict = {} if auth_token and isinstance(agent, RemoteAgent): extra_kwargs["auth_token"] = auth_token @@ -148,7 +198,6 @@ async def agent_continue_response_streamer( additional_data=e.additional_data, ) yield format_sse_event(error_response) - except Exception as e: import traceback @@ -159,8 +208,175 @@ async def agent_continue_response_streamer( error_id=e.error_id if hasattr(e, "error_id") else None, ) yield format_sse_event(error_response) + + +async def _resume_stream_generator( + agent: Union[Agent, RemoteAgent], + run_id: str, + last_event_index: Optional[int], + session_id: Optional[str], +) -> AsyncGenerator: + """SSE generator for the /resume endpoint. + + Three reconnection paths: + 1. Run still active (in buffer): replay missed events + subscribe for live events via Queue + 2. Run completed (in buffer): replay all events since last_event_index + 3. Not in buffer: fall back to database replay + """ + buffer_status = event_buffer.get_run_status(run_id) + + if buffer_status is None: + # PATH 3: Not in buffer -- fall back to database + if session_id and not isinstance(agent, RemoteAgent): + run_output = await agent.aget_run_output(run_id=run_id, session_id=session_id) + if run_output and run_output.events: + meta: dict = { + "event": "replay", + "run_id": run_id, + "status": run_output.status.value if run_output.status else "unknown", + "total_events": len(run_output.events), + "message": "Run completed. Replaying all events from database.", + } + yield f"event: replay\ndata: {json.dumps(meta)}\n\n" + + for idx, event in enumerate(run_output.events): + event_dict = event.to_dict() + event_dict["event_index"] = idx + if "run_id" not in event_dict: + event_dict["run_id"] = run_id + event_type = event_dict.get("event", "message") + yield f"event: {event_type}\ndata: {json.dumps(event_dict, separators=(',', ':'), default=json_serializer, ensure_ascii=False)}\n\n" + return + elif run_output: + meta = { + "event": "replay", + "run_id": run_id, + "status": run_output.status.value if run_output.status else "unknown", + "total_events": 0, + "message": "Run completed but no events stored.", + } + yield f"event: replay\ndata: {json.dumps(meta)}\n\n" + return + + # Run not found anywhere + error = {"event": "error", "error": f"Run {run_id} not found in buffer or database"} + yield f"event: error\ndata: {json.dumps(error)}\n\n" + return + + if buffer_status in (RunStatus.completed, RunStatus.error, RunStatus.cancelled, RunStatus.paused): + # PATH 2: Run finished -- replay missed events from buffer + total_buffered = event_buffer.get_event_count(run_id) + missed_events = event_buffer.get_events(run_id, last_event_index=last_event_index) + log_debug( + f"Resume PATH 2: run_id={run_id}, status={buffer_status.value}, " + f"last_event_index={last_event_index}, total_buffered={total_buffered}, " + f"missed_events={len(missed_events)}" + ) + + meta = { + "event": "replay", + "run_id": run_id, + "status": buffer_status.value, + "total_events": len(missed_events), + "total_buffered": total_buffered, + "last_event_index_requested": last_event_index if last_event_index is not None else -1, + "message": f"Run {buffer_status.value}. Replaying {len(missed_events)} missed events (of {total_buffered} total).", + } + yield f"event: replay\ndata: {json.dumps(meta)}\n\n" + + start_index = (last_event_index + 1) if last_event_index is not None else 0 + for idx, buffered_event in enumerate(missed_events): + event_dict = buffered_event.to_dict() + event_dict["event_index"] = start_index + idx + if "run_id" not in event_dict: + event_dict["run_id"] = run_id + event_type = event_dict.get("event", "message") + yield f"event: {event_type}\ndata: {json.dumps(event_dict, separators=(',', ':'), default=json_serializer, ensure_ascii=False)}\n\n" return + # PATH 1: Run still active -- subscribe FIRST (to avoid race condition), then replay missed events + queue = sse_subscriber_manager.subscribe(run_id) + + try: + missed_events = event_buffer.get_events(run_id, last_event_index) + current_count = event_buffer.get_event_count(run_id) + + # Track the highest replayed event_index for dedup against queue events + last_replayed_index = last_event_index if last_event_index is not None else -1 + + if missed_events: + meta = { + "event": "catch_up", + "run_id": run_id, + "status": "running", + "missed_events": len(missed_events), + "current_event_count": current_count, + "message": f"Catching up on {len(missed_events)} missed events.", + } + yield f"event: catch_up\ndata: {json.dumps(meta)}\n\n" + + start_index = (last_event_index + 1) if last_event_index is not None else 0 + for idx, buffered_event in enumerate(missed_events): + current_idx = start_index + idx + event_dict = buffered_event.to_dict() + event_dict["event_index"] = current_idx + if "run_id" not in event_dict: + event_dict["run_id"] = run_id + event_type = event_dict.get("event", "message") + yield f"event: {event_type}\ndata: {json.dumps(event_dict, separators=(',', ':'), default=json_serializer, ensure_ascii=False)}\n\n" + last_replayed_index = current_idx + + # Re-check buffer status after subscribing: the run may have completed + # between our initial status check and now. If so, replay remaining events + # from buffer instead of waiting on the queue (the sentinel was already pushed + # before our subscription existed). + updated_status = event_buffer.get_run_status(run_id) + if updated_status is not None and updated_status != RunStatus.running: + # Run completed while we were catching up -- replay remaining from buffer + remaining = event_buffer.get_events(run_id, last_event_index=last_replayed_index) + if remaining: + replay_start = last_replayed_index + 1 + for idx, buffered_event in enumerate(remaining): + current_idx = replay_start + idx + event_dict = buffered_event.to_dict() + event_dict["event_index"] = current_idx + if "run_id" not in event_dict: + event_dict["run_id"] = run_id + event_type = event_dict.get("event", "message") + yield f"event: {event_type}\ndata: {json.dumps(event_dict, separators=(',', ':'), default=json_serializer, ensure_ascii=False)}\n\n" + return + + # Confirm subscription for live events + subscribed = { + "event": "subscribed", + "run_id": run_id, + "status": "running", + "current_event_count": current_count, + "message": "Subscribed to agent run. Receiving live events.", + } + yield f"event: subscribed\ndata: {json.dumps(subscribed)}\n\n" + + log_debug(f"SSE client subscribed to agent run {run_id} (last_event_index: {last_event_index})") + + # Read from queue, dedup events already replayed by event_index + while True: + sse_data = await queue.get() + if sse_data is None: + # Sentinel: run completed + break + # Dedup: extract event_index from the SSE data and skip if already replayed + try: + data_line = sse_data.split("data: ", 1)[1].split("\n\n")[0] + parsed = json.loads(data_line) + ev_idx = parsed.get("event_index") + if ev_idx is not None and ev_idx <= last_replayed_index: + continue + except Exception: + pass + yield sse_data + finally: + sse_subscriber_manager.unsubscribe(run_id, queue) + def get_agent_router( os: "AgentOS", @@ -357,10 +573,33 @@ async def create_agent_run( # Extract auth token for remote agents auth_token = get_auth_token_from_request(request) - # Background execution: return 202 immediately with run metadata + # Background execution if background: if isinstance(agent, RemoteAgent): raise HTTPException(status_code=400, detail="Background execution is not supported for remote agents") + + if stream: + # background=True, stream=True: resumable SSE streaming + # Agent runs in a detached asyncio.Task that survives client disconnections. + # Events are buffered for reconnection via /resume endpoint. + return StreamingResponse( + agent_resumable_response_streamer( + agent, + message, + session_id=session_id, + user_id=user_id, + images=base64_images if base64_images else None, + audio=base64_audios if base64_audios else None, + videos=base64_videos if base64_videos else None, + files=input_files if input_files else None, + background_tasks=background_tasks, + auth_token=auth_token, + **kwargs, + ), + media_type="text/event-stream", + ) + + # background=True, stream=False: return 202 immediately with run metadata if not agent.db: raise HTTPException( status_code=400, detail="Background execution requires a database to be configured on the agent" @@ -763,6 +1002,50 @@ async def get_agent_run( return run_output.to_dict() + @router.post( + "/agents/{agent_id}/runs/{run_id}/resume", + tags=["Agents"], + operation_id="resume_agent_run_stream", + summary="Resume Agent Run Stream", + description=( + "Resume an SSE stream for an agent run after disconnection.\n\n" + "Sends missed events since `last_event_index`, then continues streaming " + "live events if the run is still active.\n\n" + "**Three reconnection paths:**\n" + "1. **Run still active**: Sends catch-up events + continues live streaming\n" + "2. **Run completed (in buffer)**: Replays missed buffered events\n" + "3. **Run completed (in database)**: Replays events from database\n\n" + "**Client usage:**\n" + "Track `event_index` from each SSE event. On reconnection, pass the last " + "received `event_index` as `last_event_index`." + ), + responses={ + 200: { + "description": "SSE stream of catch-up and/or live events", + "content": {"text/event-stream": {}}, + }, + 400: {"description": "Not supported for remote agents", "model": BadRequestResponse}, + 404: {"description": "Agent not found", "model": NotFoundResponse}, + }, + dependencies=[Depends(require_resource_access("agents", "run", "agent_id"))], + ) + async def resume_agent_run_stream( + agent_id: str, + run_id: str, + last_event_index: Optional[int] = Form(None, description="Index of last event received by client (0-based)"), + session_id: Optional[str] = Form(None, description="Session ID for database fallback"), + ): + agent = get_agent_by_id(agent_id=agent_id, agents=os.agents, db=os.db, registry=os.registry, create_fresh=True) + if agent is None: + raise HTTPException(status_code=404, detail="Agent not found") + if isinstance(agent, RemoteAgent): + raise HTTPException(status_code=400, detail="Stream resumption is not supported for remote agents") + + return StreamingResponse( + _resume_stream_generator(agent, run_id, last_event_index, session_id), + media_type="text/event-stream", + ) + @router.get( "/agents/{agent_id}/runs", tags=["Agents"], diff --git a/libs/agno/agno/os/utils.py b/libs/agno/agno/os/utils.py index 6b76ca181a..2cab14ee01 100644 --- a/libs/agno/agno/os/utils.py +++ b/libs/agno/agno/os/utils.py @@ -199,6 +199,42 @@ def format_sse_event(event: Union[RunOutputEvent, TeamRunOutputEvent, WorkflowRu return f"event: message\ndata: {clean_json}\n\n" +def format_sse_event_with_index( + event: Union[RunOutputEvent, TeamRunOutputEvent, WorkflowRunOutputEvent], + event_index: Optional[int] = None, + run_id: Optional[str] = None, +) -> str: + """Format an event as SSE with injected event_index and run_id. + + Used by the agent/team response streamers to include reconnection metadata + in SSE payloads without modifying the core event dataclasses. + + Args: + event: The event object to serialize. + event_index: Buffer index for reconnection tracking. + run_id: Run ID to inject if not already present on the event. + + Returns: + SSE-formatted string with event_index in the data payload. + """ + from agno.utils.serialize import json_serializer + + try: + event_type = event.event or "message" + event_dict = event.to_dict() + + if event_index is not None: + event_dict["event_index"] = event_index + if run_id and "run_id" not in event_dict: + event_dict["run_id"] = run_id + + clean_json = json.dumps(event_dict, separators=(",", ":"), default=json_serializer, ensure_ascii=False) + return f"event: {event_type}\ndata: {clean_json}\n\n" + except Exception: + clean_json = event.to_json(separators=(",", ":"), indent=None) + return f"event: message\ndata: {clean_json}\n\n" + + async def get_db( dbs: dict[str, list[Union[BaseDb, AsyncBaseDb, RemoteDb]]], db_id: Optional[str] = None, table: Optional[str] = None ) -> Union[BaseDb, AsyncBaseDb, RemoteDb]: diff --git a/libs/agno/tests/unit/agent/test_background_execution.py b/libs/agno/tests/unit/agent/test_background_execution.py index 8d10233d9b..3523936ad8 100644 --- a/libs/agno/tests/unit/agent/test_background_execution.py +++ b/libs/agno/tests/unit/agent/test_background_execution.py @@ -102,12 +102,13 @@ def test_cleanup_removes_cancel_intent(self): class TestBackgroundValidation: - def test_background_with_stream_raises_value_error(self, monkeypatch: pytest.MonkeyPatch): - """Background execution cannot be combined with streaming.""" + def test_background_with_stream_requires_db(self, monkeypatch: pytest.MonkeyPatch): + """Background execution with streaming requires a database.""" agent = Agent(name="test-agent") + agent.db = None _patch_sync_dispatch_dependencies(agent, monkeypatch, runs=[]) - with pytest.raises(ValueError, match="Background execution cannot be combined with streaming"): + with pytest.raises(ValueError, match="Background execution requires a database"): _run.arun_dispatch(agent=agent, input="hello", stream=True, background=True) def test_background_without_db_raises_value_error(self, monkeypatch: pytest.MonkeyPatch):