|
| 1 | +import json |
| 2 | +from uuid import uuid4 |
| 3 | + |
| 4 | +import httpx |
| 5 | +from a2a.client import ( |
| 6 | + A2ACardResolver, |
| 7 | + ClientConfig, |
| 8 | + ClientFactory, |
| 9 | + Consumer, |
| 10 | +) |
| 11 | +from a2a.types import ( |
| 12 | + Message, |
| 13 | + Part, |
| 14 | + Role, |
| 15 | + TextPart, |
| 16 | + DataPart, |
| 17 | +) |
| 18 | + |
| 19 | + |
| 20 | +DEFAULT_TIMEOUT = 300 |
| 21 | + |
| 22 | + |
| 23 | +def create_message(*, role: Role = Role.user, text: str, context_id: str | None = None) -> Message: |
| 24 | + return Message( |
| 25 | + kind="message", |
| 26 | + role=role, |
| 27 | + parts=[Part(TextPart(kind="text", text=text))], |
| 28 | + message_id=uuid4().hex, |
| 29 | + context_id=context_id |
| 30 | + ) |
| 31 | + |
| 32 | +def merge_parts(parts: list[Part]) -> str: |
| 33 | + chunks = [] |
| 34 | + for part in parts: |
| 35 | + if isinstance(part.root, TextPart): |
| 36 | + chunks.append(part.root.text) |
| 37 | + elif isinstance(part.root, DataPart): |
| 38 | + chunks.append(json.dumps(part.root.data, indent=2)) |
| 39 | + return "\n".join(chunks) |
| 40 | + |
| 41 | +async def send_message(message: str, base_url: str, context_id: str | None = None, streaming=False, consumer: Consumer | None = None): |
| 42 | + """Returns dict with context_id, response and status (if exists)""" |
| 43 | + async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as httpx_client: |
| 44 | + resolver = A2ACardResolver(httpx_client=httpx_client, base_url=base_url) |
| 45 | + agent_card = await resolver.get_agent_card() |
| 46 | + config = ClientConfig( |
| 47 | + httpx_client=httpx_client, |
| 48 | + streaming=streaming, |
| 49 | + ) |
| 50 | + factory = ClientFactory(config) |
| 51 | + client = factory.create(agent_card) |
| 52 | + if consumer: |
| 53 | + await client.add_event_consumer(consumer) |
| 54 | + |
| 55 | + outbound_msg = create_message(text=message, context_id=context_id) |
| 56 | + last_event = None |
| 57 | + outputs = { |
| 58 | + "response": "", |
| 59 | + "context_id": None |
| 60 | + } |
| 61 | + |
| 62 | + # if streaming == False, only one event is generated |
| 63 | + async for event in client.send_message(outbound_msg): |
| 64 | + last_event = event |
| 65 | + |
| 66 | + match last_event: |
| 67 | + case Message() as msg: |
| 68 | + outputs["context_id"] = msg.context_id |
| 69 | + outputs["response"] += merge_parts(msg.parts) |
| 70 | + |
| 71 | + case (task, update): |
| 72 | + outputs["context_id"] = task.context_id |
| 73 | + outputs["status"] = task.status.state.value |
| 74 | + msg = task.status.message |
| 75 | + if msg: |
| 76 | + outputs["response"] += merge_parts(msg.parts) |
| 77 | + if task.artifacts: |
| 78 | + for artifact in task.artifacts: |
| 79 | + outputs["response"] += merge_parts(artifact.parts) |
| 80 | + |
| 81 | + case _: |
| 82 | + pass |
| 83 | + |
| 84 | + return outputs |
| 85 | + |
| 86 | + |
| 87 | +class Messenger: |
| 88 | + def __init__(self): |
| 89 | + self._context_ids = {} |
| 90 | + |
| 91 | + async def talk_to_agent(self, message: str, url: str, new_conversation: bool = False): |
| 92 | + """ |
| 93 | + Communicate with another agent by sending a message and receiving their response. |
| 94 | +
|
| 95 | + Args: |
| 96 | + message: The message to send to the agent |
| 97 | + url: The agent's URL endpoint |
| 98 | + new_conversation: If True, start fresh conversation; if False, continue existing conversation |
| 99 | +
|
| 100 | + Returns: |
| 101 | + str: The agent's response message |
| 102 | + """ |
| 103 | + outputs = await send_message(message=message, base_url=url, context_id=None if new_conversation else self._context_ids.get(url, None)) |
| 104 | + if outputs.get("status", "completed") != "completed": |
| 105 | + raise RuntimeError(f"{url} responded with: {outputs}") |
| 106 | + self._context_ids[url] = outputs.get("context_id", None) |
| 107 | + return outputs["response"] |
| 108 | + |
| 109 | + def reset(self): |
| 110 | + self._context_ids = {} |
0 commit comments