Skip to content

Commit aad5b78

Browse files
authored
feat: support A2A protocol (#657)
1 parent 4658750 commit aad5b78

File tree

11 files changed

+675
-3
lines changed

11 files changed

+675
-3
lines changed

.libraries-whitelist.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,4 @@ psycopg
1111
pytest-postgresql
1212
python-bidi
1313
griffe
14+
types-requests
Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
import json
2+
from typing import TypeVar
3+
4+
import requests
5+
from pydantic import BaseModel
6+
7+
from ragbits.agents import Agent, AgentOptions, AgentResult, ToolCallResult
8+
from ragbits.core.llms import LiteLLM
9+
from ragbits.core.options import Options
10+
from ragbits.core.prompt import ChatFormat, Prompt
11+
12+
OptionsT = TypeVar("OptionsT", bound=Options)
13+
14+
15+
class RoutingPromptInput(BaseModel):
    """Represents a routing prompt input."""

    # The user query that should be routed to one or more remote agents.
    message: str
    # Remote-agent metadata dicts (name, agent_url, description, skills),
    # as produced by AgentOrchestrator._list_remote_agents().
    agents: list
20+
21+
22+
# NOTE(review): class name misspells "Summarization"; renaming would break
# existing callers, so the spelling is kept as-is.
class ResultsSumarizationPromptInput(BaseModel):
    """Represents a results summarization prompt input."""

    # The original user message the results answer.
    message: str
    # Collected results from remote agents (AgentResult objects at runtime;
    # annotated loosely as `list` — tightening it would change pydantic validation).
    agent_results: list
27+
28+
29+
class RemoteAgentTask(BaseModel):
    """Model representing a task for a remote agent"""

    # Base URL of the remote agent that should execute this task.
    agent_url: str
    # Parameters forwarded to the remote agent in the request payload.
    parameters: dict
34+
35+
36+
class AgentOrchestrator(Agent):
    """
    Coordinates querying and aggregating responses from multiple remote agents
    using tools for routing and task execution.
    """

    def __init__(
        self,
        llm: LiteLLM,
        routing_prompt: type[Prompt[RoutingPromptInput, list[dict]]],
        results_summarization_prompt: type[Prompt[ResultsSumarizationPromptInput, list]],
        timeout: float = 20.0,
        *,
        history: ChatFormat | None = None,
        keep_history: bool = False,
        default_options: AgentOptions[OptionsT] | None = None,
    ):
        """
        Initialize the orchestrator with tools for agent coordination.

        Args:
            llm: The LLM to run the agent.
            routing_prompt: Prompt template for routing messages to agents.
            results_summarization_prompt: Prompt template for summarizing agent results.
            timeout: Timeout in seconds for the HTTP request.
            history: The history of the agent.
            keep_history: Whether to keep the history of the agent.
            default_options: The default options for the agent run.
        """
        super().__init__(
            llm=llm,
            prompt=None,
            history=history,
            keep_history=keep_history,
            tools=[self.create_agent_tasks, self.execute_agent_task, self.summarize_agent_results],
            default_options=default_options,
        )

        self._timeout = timeout

        self._routing_prompt = routing_prompt
        self._results_summarization_prompt = results_summarization_prompt

        # Registered remote agents, keyed by base URL; values are agent-card JSON dicts.
        self._remote_agents: dict[str, dict] = {}

        # Working state populated by the tools during a single orchestration run.
        self._current_tasks: list[RemoteAgentTask] = []
        self._current_results: list[AgentResult] = []

    def add_remote_agent(self, host: str, port: int, protocol: str = "http") -> None:
        """
        Discovers and registers a remote agent by fetching its agent card metadata.

        Args:
            host: The hostname or IP address of the remote agent.
            port: The port on which the remote agent server is running.
            protocol: The communication protocol (http or https). Defaults to "http".

        Raises:
            requests.HTTPError: If the agent card endpoint returns an error status.
        """
        url = f"{protocol}://{host}:{port}"
        if url not in self._remote_agents:
            response = requests.get(f"{url}/.well-known/agent.json", timeout=self._timeout)
            # Fail fast on HTTP errors instead of caching an error payload as an agent card.
            response.raise_for_status()
            self._remote_agents[url] = response.json()

    async def create_agent_tasks(self, message: str) -> str:
        """
        Creates tasks for remote agents based on the input message.

        Args:
            message: The user query to route to agents

        Returns:
            JSON string of created tasks
        """
        prompt_input = RoutingPromptInput(message=message, agents=self._list_remote_agents())
        prompt = self._routing_prompt(prompt_input)
        response = await self.llm.generate(prompt)

        # NOTE(review): assumes the LLM response is valid JSON describing a list of
        # tasks; json.loads will raise ValueError otherwise.
        tasks = json.loads(response)
        self._current_tasks = [RemoteAgentTask(**task) for task in tasks]
        # Reset results so a new routing round never mixes in results from a
        # previous query when summarizing.
        self._current_results = []

        return json.dumps(
            {
                "status": "success",
                "task_count": len(self._current_tasks),
                "tasks": [task.model_dump() for task in self._current_tasks],
            }
        )

    async def execute_agent_task(self, task_index: int) -> str:
        """
        Executes a specific task from the current task list.

        Args:
            task_index: Index of the task to execute

        Returns:
            JSON string of the execution result
        """
        # Reject negative indices too: a bare `task_index >= len(...)` check would
        # let e.g. -1 silently execute the last task.
        if not 0 <= task_index < len(self._current_tasks):
            return json.dumps({"status": "error", "message": "Invalid task index"})

        task = self._current_tasks[task_index]
        result = self._execute_single_task(task.agent_url, task.parameters)
        self._current_results.append(result)

        tool_calls = None
        if result.tool_calls:
            tool_calls = [{"name": tc.name, "arguments": tc.arguments, "output": tc.output} for tc in result.tool_calls]

        return json.dumps(
            {
                "status": "success",
                "agent_url": task.agent_url,
                "result": {"content": result.content, "metadata": result.metadata, "tool_calls": tool_calls},
            }
        )

    async def summarize_agent_results(self, message: str) -> str:
        """
        Summarizes all collected agent results.

        Args:
            message: The original user message

        Returns:
            The summarized response
        """
        if not self._current_results:
            return "No results to summarize"

        input_data = ResultsSumarizationPromptInput(message=message, agent_results=self._current_results)
        prompt = self._results_summarization_prompt(input_data=input_data)
        return await self.llm.generate(prompt)

    def _execute_single_task(self, agent_url: str, params: dict) -> AgentResult:
        """
        Sends one task to a remote agent over HTTP and parses the response.

        Args:
            agent_url: Base URL of the remote agent.
            params: Parameters forwarded to the agent as the "params" payload.

        Returns:
            The remote agent's reply as an AgentResult.

        Raises:
            requests.HTTPError: If the remote agent returns an error status.
        """
        payload = {"params": params}
        raw_response = requests.post(agent_url, json=payload, timeout=self._timeout)
        raw_response.raise_for_status()

        response = raw_response.json()
        result_data = response["result"]

        return AgentResult(
            content=result_data["content"],
            metadata=result_data.get("metadata", {}),
            history=result_data["history"],
            # Normalize an empty tool-call list to None to match AgentResult usage above.
            tool_calls=[ToolCallResult(**call) for call in result_data.get("tool_calls", [])] or None,
        )

    def _list_remote_agents(self) -> list[dict]:
        """
        Lists metadata of all registered remote agents in a format suitable for routing.

        Returns:
            A list of dictionaries describing each remote agent's name, URL,
            description, and skills.
        """
        return [
            {
                "name": data.get("name"),
                "agent_url": url,
                "description": data.get("description"),
                "skills": [
                    {"id": skill.get("id"), "name": skill.get("name"), "description": skill.get("description")}
                    for skill in data.get("skills", [])
                ],
            }
            for url, data in self._remote_agents.items()
        ]
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
import json
2+
3+
from pydantic import BaseModel
4+
5+
from ragbits.agents import Agent
6+
from ragbits.agents.a2a.server import create_agent_app, create_agent_server
7+
from ragbits.core.llms import LiteLLM
8+
from ragbits.core.prompt import Prompt
9+
10+
11+
def get_flight_info(departure: str, arrival: str) -> str:
    """
    Returns flight information between two locations.

    Args:
        departure: The departure city.
        arrival: The arrival city.

    Returns:
        A JSON string with mock flight details.
    """
    # Case-insensitive substring matching, so inputs like "New York City" still match.
    dep = departure.lower()
    arr = arrival.lower()

    if "new york" in dep and "paris" in arr:
        payload = {
            "from": "New York",
            "to": "Paris",
            "flights": [
                {"airline": "British Airways", "departure": "10:00 AM", "arrival": "10:00 PM"},
                {"airline": "Delta", "departure": "1:00 PM", "arrival": "1:00 AM"},
            ],
        }
    elif "los angeles" in dep and "tokyo" in arr:
        payload = {
            "from": "Los Angeles",
            "to": "Tokyo",
            "flights": [
                {"airline": "ANA", "departure": "8:00 AM", "arrival": "12:00 PM"},
                {"airline": "JAL", "departure": "4:00 PM", "arrival": "8:00 AM"},
            ],
        }
    else:
        # Unknown route: echo the caller's original spelling back.
        payload = {"from": departure, "to": arrival, "flights": "No flight data available"}

    return json.dumps(payload)
46+
47+
48+
class FlightPromptInput(BaseModel):
    """Defines the structured input schema for the flight search prompt."""

    # Departure city, interpolated into the user prompt.
    departure: str
    # Arrival city, interpolated into the user prompt.
    arrival: str
53+
54+
55+
class FlightPrompt(Prompt[FlightPromptInput]):
    """Prompt for a flight search assistant."""

    # System message establishing the assistant's role.
    system_prompt = """
    You are a helpful travel assistant that finds available flights between two cities.
    """

    # Template rendered with FlightPromptInput's departure/arrival fields.
    user_prompt = """
    I need to fly from {{ departure }} to {{ arrival }}. What flights are available?
    """
65+
66+
67+
# Module-level wiring: build the LLM, the agent, and the A2A app/server objects
# so this file can be served directly (e.g. by an ASGI runner importing `app`).
llm = LiteLLM(
    model_name="gpt-4o-2024-08-06",
    use_structured_output=True,
)
agent = Agent(llm=llm, prompt=FlightPrompt, tools=[get_flight_info])
# Agent card: metadata (name, description, skills) advertised to orchestrators.
agent_card = agent.get_agent_card(
    name="Flight Info Agent",
    description="Provides available flight information between two cities.",
)

app = create_agent_app(agent, agent_card, FlightPromptInput)
server = create_agent_server(agent, agent_card, FlightPromptInput)

examples/agents/a2a/hotel_agent.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
import json
2+
3+
from pydantic import BaseModel
4+
5+
from ragbits.agents import Agent
6+
from ragbits.agents.a2a.server import create_agent_app, create_agent_server
7+
from ragbits.core.llms import LiteLLM
8+
from ragbits.core.prompt import Prompt
9+
10+
11+
def get_hotel_recommendations(city: str) -> str:
    """
    Returns hotel recommendations for a given city and date range.

    Args:
        city: The destination city.

    Returns:
        A JSON string with mock hotel details.
    """
    # Case-insensitive substring matching ("Paris, France" matches "paris").
    key = city.lower()

    if "paris" in key:
        result = {
            "city": "Paris",
            "hotels": [
                {"name": "Hotel Le Meurice", "rating": 5, "price_per_night": 450},
                {"name": "Hotel Regina Louvre", "rating": 4, "price_per_night": 300},
            ],
        }
    elif "rome" in key:
        result = {
            "city": "Rome",
            "hotels": [
                {"name": "Hotel Eden", "rating": 5, "price_per_night": 400},
                {"name": "Hotel Artemide", "rating": 4, "price_per_night": 250},
            ],
        }
    else:
        # Unknown city: echo the caller's original spelling back.
        result = {
            "city": city,
            "hotels": "No hotel data available",
        }

    return json.dumps(result)
50+
51+
52+
class HotelPromptInput(BaseModel):
    """Defines the structured input for the hotel recommendation prompt."""

    # Destination city, interpolated into the user prompt.
    city: str
56+
57+
58+
class HotelPrompt(Prompt[HotelPromptInput]):
    """Prompt for a hotel recommendation assistant."""

    # System message establishing the assistant's role.
    system_prompt = """
    You are a helpful travel assistant that recommends hotels based on the city and travel dates.
    """

    # Template rendered with HotelPromptInput's city field.
    user_prompt = """
    I'm planning a trip to {{ city }}. Can you recommend some hotels?
    """
68+
69+
70+
# Module-level wiring: build the LLM, the agent, and the A2A app/server objects
# so this file can be served directly (e.g. by an ASGI runner importing `app`).
llm = LiteLLM(
    model_name="gpt-4o-2024-08-06",
    use_structured_output=True,
)

agent = Agent(llm=llm, prompt=HotelPrompt, tools=[get_hotel_recommendations])

# NOTE(review): port is passed as the string "8001" — confirm get_agent_card
# expects a str here rather than an int.
agent_card = agent.get_agent_card(
    name="Hotel Recommendation Agent", description="Recommends hotels for a given city and travel dates.", port="8001"
)

app = create_agent_app(agent, agent_card, HotelPromptInput)
server = create_agent_server(agent, agent_card, HotelPromptInput)

0 commit comments

Comments
 (0)