-
Notifications
You must be signed in to change notification settings - Fork 324
Description
What happened?
I integrated a2a-python SDK with Smolagents framework with stream output. The agent outputs stream messages in real-time, and the messages are enqueued in TestAgentExecutor.execute. But dequeuing messages always start after finishing enqueuing all.
Here is my test.py code (Replace LLM id/key/base_url with yours):
import logging
import time
from a2a.server.agent_execution import AgentExecutor, RequestContext
from a2a.server.apps import A2AStarletteApplication
from a2a.server.events import EventQueue
from a2a.server.request_handlers import DefaultRequestHandler
from a2a.server.tasks import InMemoryTaskStore
from a2a.types import (
AgentCapabilities,
AgentCard,
AgentSkill,
TaskState,
TaskStatus,
TaskStatusUpdateEvent,
)
from a2a.utils import new_agent_text_message, new_task
from fastapi import FastAPI
from smolagents import CodeAgent, OpenAIServerModel, WebSearchTool
from smolagents.memory import ActionStep, FinalAnswerStep
from smolagents.models import ChatMessageStreamDelta
from typing_extensions import override
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
class TestAgentExecutor(AgentExecutor):
def __init__(self):
model = OpenAIServerModel(
model_id="{YOUR_MODEL_ID}",
api_base="{YOUR_API_BASE}",
api_key="{YOUR_API_KEY}")
self.agent = CodeAgent(model=model,
tools=[WebSearchTool()],
stream_outputs=True)
@override
async def execute(
self,
context: RequestContext,
event_queue: EventQueue,
) -> None:
query = context.get_user_input()
task = context.current_task
if not context.message:
raise Exception('No message provided')
if not task:
task = new_task(context.message)
await event_queue.enqueue_event(task)
start = time.time()
try:
for message in self.agent.run(query, stream=True):
if isinstance(message, ChatMessageStreamDelta):
text = message.content
elif isinstance(message, ActionStep):
text = message.model_output
elif isinstance(message, FinalAnswerStep):
text = message.output
print(f"***{time.time() - start}: {text}")
await self._send_working_message(context, event_queue, text)
except Exception as e:
logger.error(f"Error in streaming output: {str(e)}")
finally:
await self._send_final_answer(context, event_queue)
async def _send_working_message(self, context, event_queue, text):
await event_queue.enqueue_event(
TaskStatusUpdateEvent(
append=True,
status=TaskStatus(
state=TaskState.working,
message=new_agent_text_message(
text,
context.context_id,
context.task_id,
),
),
final=False,
contextId=context.context_id,
taskId=context.task_id,
)
)
async def _send_final_answer(self, context, event_queue):
await event_queue.enqueue_event(
TaskStatusUpdateEvent(
status=TaskStatus(state=TaskState.completed),
final=True,
contextId=context.context_id,
taskId=context.task_id,
)
)
@override
async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None:
raise Exception("Cancel not supported")
if __name__ == "__main__":
agent_url = f"/api/v1/users/1/a2a/"
skill = AgentSkill(
id="assistant_agent_skill",
name="I can do everything for you",
description="I can assist you with a wide range of tasks, from chatting to calling tools.",
tags=["assistant", "chat"],
examples=["What's the weather in Shanghai today?"],
)
agent_card = AgentCard(
name="Assistant Agent",
description="Your personal assistant agent",
url="http://localhost:5000" + agent_url,
version="1.0.0",
defaultInputModes=["text"],
defaultOutputModes=["text"],
capabilities=AgentCapabilities(streaming=True),
skills=[skill],
)
request_handler = DefaultRequestHandler(
agent_executor=TestAgentExecutor(),
task_store=InMemoryTaskStore(),
)
server = A2AStarletteApplication(
agent_card=agent_card,
http_handler=request_handler,
)
app = FastAPI()
app.mount(agent_url, server.build())
import uvicorn
config = uvicorn.Config(app, host="0.0.0.0", port=5000)
server = uvicorn.Server(config)
server.run()
Run the code : python test.py 2>&1 | tee -a test.log &
Test with cURL: curl --request POST \ --url http://localhost:5000/api/v1/users/1/a2a/ \ --data '{"id":"1234","jsonrpc":"2.0","method":"message/send","params":{"message":{"role":"user","parts":[{"kind":"text","text":"推荐一部科幻电影"}],"messageId":"user123"}}}'
Symptom:
Stream outputs (log messages starting with ***) start from 1.142s, and also enqueued.
But Dequeued event (waited) of type: <class 'a2a.types.Task'> starts very late (always after Closing EventQueue) from Line4802 in the log (part of the log attached below).
Relevant log output
***1.1423838138580322: Thought
DEBUG:a2a.utils.telemetry:Start async tracer
DEBUG:a2a.server.events.event_queue:Enqueuing event of type: <class 'a2a.types.TaskStatusUpdateEvent'>
...
DEBUG:a2a.utils.telemetry:Start async tracer
DEBUG:a2a.server.events.event_queue:Enqueuing event of type: <class 'a2a.types.TaskStatusUpdateEvent'>
DEBUG:a2a.utils.telemetry:Start async tracer
DEBUG:a2a.server.events.event_queue:Enqueuing event of type: <class 'a2a.types.TaskStatusUpdateEvent'>
DEBUG:a2a.utils.telemetry:Start async tracer
DEBUG:a2a.server.events.event_queue:Closing EventQueue.
DEBUG:a2a.server.events.event_queue:Dequeued event (waited) of type: <class 'a2a.types.Task'>
DEBUG:a2a.server.events.event_consumer:Dequeued event of type: <class 'a2a.types.Task'> in consume_all.
DEBUG:a2a.server.events.event_queue:Marking task as done in EventQueue.
DEBUG:a2a.server.events.event_consumer:Marked task as done in event queue in consume_all
DEBUG:a2a.server.tasks.task_manager:Processing save of task event of type Task for task_id: 2aee89ae-96eb-4294-be25-3b761d7bbf75
DEBUG:a2a.server.tasks.task_manager:Saving task with id: 2aee89ae-96eb-4294-be25-3b761d7bbf75
DEBUG:a2a.server.tasks.inmemory_task_store:Task 2aee89ae-96eb-4294-be25-3b761d7bbf75 saved successfully.Code of Conduct
- I agree to follow this project's Code of Conduct