Skip to content

Commit 462946b

Browse files
committed
Refactoring to address comments
Change-Id: I5eca290bd814eb7480175d14822d7833f815bd3a
1 parent 29bc903 commit 462946b

File tree

9 files changed

+198
-151
lines changed

9 files changed

+198
-151
lines changed

examples/demo/agent_proxy.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,23 +6,23 @@
66
from a2a.types import (
77
Artifact,
88
CancelTaskRequest,
9-
CancelTaskSuccessResponse,
9+
CancelTaskResponse,
1010
JSONRPCErrorResponse,
1111
SendTaskRequest,
12+
SendTaskResponse,
1213
SendTaskStreamingRequest,
1314
SendTaskStreamingResponse,
1415
SendTaskStreamingSuccessResponse,
1516
SendTaskSuccessResponse,
1617
Task,
1718
TaskArtifactUpdateEvent,
18-
TaskNotCancelableError,
1919
TaskResubscriptionRequest,
2020
TaskState,
2121
TaskStatus,
2222
TaskStatusUpdateEvent,
2323
UnsupportedOperationError,
2424
)
25-
from a2a.utils import get_text_artifact
25+
from a2a.utils import build_text_artifact
2626

2727

2828
class FakeAgent:
@@ -45,17 +45,19 @@ def __init__(self):
4545

4646
async def on_send(
4747
self, task: Task, request: SendTaskRequest
48-
) -> SendTaskSuccessResponse | JSONRPCErrorResponse:
48+
) -> SendTaskResponse:
4949
result = await self.agent.invoke()
5050

5151
if not task.artifacts:
5252
task.artifacts = []
5353

54-
artifact: Artifact = get_text_artifact(result, len(task.artifacts))
54+
artifact: Artifact = build_text_artifact(result, len(task.artifacts))
5555
task.artifacts.append(artifact)
5656
task.status.state = TaskState.completed
5757

58-
return SendTaskSuccessResponse(id=request.id, result=task)
58+
return SendTaskResponse(
59+
root=SendTaskSuccessResponse(id=request.id, result=task)
60+
)
5961

6062
async def on_send_subscribe( # type: ignore
6163
self, task: Task, request: SendTaskStreamingRequest
@@ -67,7 +69,7 @@ async def on_send_subscribe( # type: ignore
6769
async for chunk in self.agent.stream():
6870
artifact_update = TaskArtifactUpdateEvent(
6971
id=task.id,
70-
artifact=get_text_artifact(chunk, new_index),
72+
artifact=build_text_artifact(chunk, new_index),
7173
append=i > 0,
7274
lastChunk=False, # TODO: set this value, but is this needed?
7375
)
@@ -91,9 +93,11 @@ async def on_send_subscribe( # type: ignore
9193

9294
async def on_cancel(
9395
self, task: Task, request: CancelTaskRequest
94-
) -> CancelTaskSuccessResponse | JSONRPCErrorResponse:
95-
return JSONRPCErrorResponse(
96-
id=request.id, error=TaskNotCancelableError()
96+
) -> CancelTaskResponse:
97+
return CancelTaskResponse(
98+
root=JSONRPCErrorResponse(
99+
id=request.id, error=UnsupportedOperationError()
100+
)
97101
)
98102

99103
async def on_resubscribe( # type: ignore

examples/langgraph/agent.py

Lines changed: 85 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,20 @@
1+
import logging
2+
3+
from collections.abc import AsyncIterable
4+
from typing import Any, Literal
5+
6+
import httpx
7+
8+
from langchain_core.messages import AIMessage, ToolMessage
9+
from langchain_core.runnables.config import (
10+
RunnableConfig,
11+
)
12+
from langchain_core.tools import tool # type: ignore
113
from langchain_google_genai import ChatGoogleGenerativeAI
2-
from langchain_core.tools import tool
3-
from langgraph.prebuilt import create_react_agent
414
from langgraph.checkpoint.memory import MemorySaver
5-
from langchain_core.messages import AIMessage, ToolMessage
6-
import httpx
7-
from typing import Any, AsyncIterable, Literal
15+
from langgraph.prebuilt import create_react_agent # type: ignore
816
from pydantic import BaseModel
9-
import logging
17+
1018

1119
logger = logging.getLogger(__name__)
1220

@@ -15,9 +23,9 @@
1523

1624
@tool
1725
def get_exchange_rate(
18-
currency_from: str = "USD",
19-
currency_to: str = "EUR",
20-
currency_date: str = "latest",
26+
currency_from: str = 'USD',
27+
currency_to: str = 'EUR',
28+
currency_date: str = 'latest',
2129
):
2230
"""Use this to get current exchange rate.
2331
@@ -28,118 +36,120 @@ def get_exchange_rate(
2836
2937
Returns:
3038
A dictionary containing the exchange rate data, or an error message if the request fails.
31-
"""
39+
"""
3240
try:
3341
response = httpx.get(
34-
f"https://api.frankfurter.app/{currency_date}",
35-
params={"from": currency_from, "to": currency_to},
42+
f'https://api.frankfurter.app/{currency_date}',
43+
params={'from': currency_from, 'to': currency_to},
3644
)
3745
response.raise_for_status()
3846

3947
data = response.json()
40-
if "rates" not in data:
41-
logger.error(f"rates not found in response: {data}")
42-
return {"error": "Invalid API response format."}
43-
logger.info(f"API response: {data}")
48+
if 'rates' not in data:
49+
logger.error(f'rates not found in response: {data}')
50+
return {'error': 'Invalid API response format.'}
51+
logger.info(f'API response: {data}')
4452
return data
4553
except httpx.HTTPError as e:
46-
logger.error(f"API request failed: {e}")
47-
return {"error": f"API request failed: {e}"}
54+
logger.error(f'API request failed: {e}')
55+
return {'error': f'API request failed: {e}'}
4856
except ValueError:
49-
logger.error("Invalid JSON response from API")
50-
return {"error": "Invalid JSON response from API."}
57+
logger.error('Invalid JSON response from API')
58+
return {'error': 'Invalid JSON response from API.'}
5159

5260

5361
class ResponseFormat(BaseModel):
5462
"""Respond to the user in this format."""
55-
status: Literal["input_required", "completed", "error"] = "input_required"
63+
64+
status: Literal['input_required', 'completed', 'error'] = 'input_required'
5665
message: str
5766

67+
5868
class CurrencyAgent:
69+
"""Currency Conversion Agent Example."""
5970

6071
SYSTEM_INSTRUCTION = (
61-
"You are a specialized assistant for currency conversions. "
72+
'You are a specialized assistant for currency conversions. '
6273
"Your sole purpose is to use the 'get_exchange_rate' tool to answer questions about currency exchange rates. "
63-
"If the user asks about anything other than currency conversion or exchange rates, "
64-
"politely state that you cannot help with that topic and can only assist with currency-related queries. "
65-
"Do not attempt to answer unrelated questions or use tools for other purposes."
74+
'If the user asks about anything other than currency conversion or exchange rates, '
75+
'politely state that you cannot help with that topic and can only assist with currency-related queries. '
76+
'Do not attempt to answer unrelated questions or use tools for other purposes.'
6677
)
6778

6879
RESPONSE_FORMAT_INSTRUCTION: str = (
69-
"Select status as completed if the request is complete"
70-
"Select status as input_required if the input is a question to the user"
71-
"Set response status to error if the input indicates an error"
80+
'Select status as completed if the request is complete'
81+
'Select status as input_required if the input is a question to the user'
82+
'Set response status to error if the input indicates an error'
7283
)
73-
84+
7485
def __init__(self):
75-
self.model = ChatGoogleGenerativeAI(model="gemini-2.0-flash")
86+
self.model = ChatGoogleGenerativeAI(model='gemini-2.0-flash')
7687
self.tools = [get_exchange_rate]
7788

7889
self.graph = create_react_agent(
79-
self.model, tools=self.tools, checkpointer=memory, prompt = self.SYSTEM_INSTRUCTION, response_format=(self.RESPONSE_FORMAT_INSTRUCTION, ResponseFormat)
90+
self.model,
91+
tools=self.tools,
92+
checkpointer=memory,
93+
prompt=self.SYSTEM_INSTRUCTION,
94+
response_format=(self.RESPONSE_FORMAT_INSTRUCTION, ResponseFormat),
8095
)
8196

8297
def invoke(self, query: str, sessionId: str) -> dict[str, Any]:
83-
config: dict[str, Any] = {"configurable": {"thread_id": sessionId}}
84-
self.graph.invoke({"messages": [("user", query)]}, config)
85-
response = self.get_agent_response(config)
86-
return response
87-
88-
89-
async def stream(self, query: str, sessionId: str) -> AsyncIterable[dict[str, Any]]:
90-
inputs: dict[str, Any] = {"messages": [("user", query)]}
91-
config: dict[str, Any] = {"configurable": {"thread_id": sessionId}}
92-
93-
for item in self.graph.stream(inputs, config, stream_mode="values"):
94-
message = item["messages"][-1]
98+
config: RunnableConfig = {'configurable': {'thread_id': sessionId}}
99+
self.graph.invoke({'messages': [('user', query)]}, config)
100+
return self.get_agent_response(config)
101+
102+
async def stream(
103+
self, query: str, sessionId: str
104+
) -> AsyncIterable[dict[str, Any]]:
105+
inputs: dict[str, Any] = {'messages': [('user', query)]}
106+
config: RunnableConfig = {'configurable': {'thread_id': sessionId}}
107+
108+
for item in self.graph.stream(inputs, config, stream_mode='values'):
109+
message = item['messages'][-1]
95110
if (
96111
isinstance(message, AIMessage)
97112
and message.tool_calls
98113
and len(message.tool_calls) > 0
99114
):
100115
yield {
101-
"is_task_complete": False,
102-
"require_user_input": False,
103-
"content": "Looking up the exchange rates...",
116+
'is_task_complete': False,
117+
'require_user_input': False,
118+
'content': 'Looking up the exchange rates...',
104119
}
105120
elif isinstance(message, ToolMessage):
106121
yield {
107-
"is_task_complete": False,
108-
"require_user_input": False,
109-
"content": "Processing the exchange rates..",
110-
}
111-
122+
'is_task_complete': False,
123+
'require_user_input': False,
124+
'content': 'Processing the exchange rates..',
125+
}
126+
112127
yield self.get_agent_response(config)
113128

114-
115-
def get_agent_response(self, config: dict[str, Any]) -> dict[str, Any]:
116-
current_state = self.graph.get_state(config)
117-
118-
structured_response = current_state.values.get('structured_response')
119-
if structured_response and isinstance(structured_response, ResponseFormat):
120-
if structured_response.status == "input_required":
121-
return {
122-
"is_task_complete": False,
123-
"require_user_input": True,
124-
"content": structured_response.message
125-
}
126-
elif structured_response.status == "error":
129+
def get_agent_response(self, config: RunnableConfig) -> dict[str, Any]:
130+
current_state = self.graph.get_state(config)
131+
132+
structured_response = current_state.values.get('structured_response')
133+
if structured_response and isinstance(
134+
structured_response, ResponseFormat
135+
):
136+
if structured_response.status in {'input_required', 'error'}:
127137
return {
128-
"is_task_complete": False,
129-
"require_user_input": True,
130-
"content": structured_response.message
138+
'is_task_complete': False,
139+
'require_user_input': True,
140+
'content': structured_response.message,
131141
}
132-
elif structured_response.status == "completed":
142+
if structured_response.status == 'completed':
133143
return {
134-
"is_task_complete": True,
135-
"require_user_input": False,
136-
"content": structured_response.message
144+
'is_task_complete': True,
145+
'require_user_input': False,
146+
'content': structured_response.message,
137147
}
138148

139149
return {
140-
"is_task_complete": False,
141-
"require_user_input": True,
142-
"content": "We are unable to process your request at the moment. Please try again.",
150+
'is_task_complete': False,
151+
'require_user_input': True,
152+
'content': 'We are unable to process your request at the moment. Please try again.',
143153
}
144154

145-
SUPPORTED_CONTENT_TYPES = ["text", "text/plain"]
155+
SUPPORTED_CONTENT_TYPES = ['text', 'text/plain']

0 commit comments

Comments
 (0)