Skip to content

Commit 9a3b7ec

Browse files
committed
update test_client.py
1 parent 8372caf commit 9a3b7ec

File tree

1 file changed

+30
-80
lines changed

1 file changed

+30
-80
lines changed

examples/a2a-adk-app/airbnb_agent/test_client.py

Lines changed: 30 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -5,34 +5,36 @@
55

66
from a2a.client import A2AClient
77
from a2a.types import (
8-
GetTaskResponse,
98
SendMessageResponse,
9+
GetTaskResponse,
1010
SendMessageSuccessResponse,
1111
Task,
12-
TaskState,
12+
SendMessageRequest,
13+
MessageSendParams,
14+
GetTaskRequest,
15+
TaskQueryParams,
1316
)
14-
17+
import traceback
1518

1619
AGENT_URL = "http://localhost:10002"
1720

18-
1921
def create_send_message_payload(
2022
text: str, task_id: str | None = None, context_id: str | None = None
2123
) -> dict[str, Any]:
2224
"""Helper function to create the payload for sending a task."""
2325
payload: dict[str, Any] = {
24-
"message": {
25-
"role": "user",
26-
"parts": [{"type": "text", "text": text}],
27-
"messageId": uuid4().hex,
26+
'message': {
27+
'role': 'user',
28+
'parts': [{'kind': 'text', 'text': text}],
29+
'messageId': uuid4().hex,
2830
},
2931
}
3032

3133
if task_id:
32-
payload["message"]["taskId"] = task_id
34+
payload['message']['taskId'] = task_id
3335

3436
if context_id:
35-
payload["message"]["contextId"] = context_id
37+
payload['message']['contextId'] = context_id
3638
return payload
3739

3840

@@ -49,98 +51,46 @@ async def run_single_turn_test(client: A2AClient) -> None:
4951
"""Runs a single-turn non-streaming test."""
5052

5153
send_payload = create_send_message_payload(
52-
text="Please find a bedroom in new york city, June 15-20, 2025"
54+
text='Please find a bedroom in new york city, June 15-20, 2025'
5355
)
54-
print("send_payload", send_payload)
56+
request = SendMessageRequest(params=MessageSendParams(**send_payload))
57+
58+
print('--- Single Turn Request ---')
5559
# Send Message
56-
send_response: SendMessageResponse = await client.send_message(payload=send_payload)
57-
print("send_response", send_response)
58-
print_json_response(send_response, "Single Turn Request Response")
60+
send_response: SendMessageResponse = await client.send_message(request)
61+
print_json_response(send_response, 'Single Turn Request Response')
5962
if not isinstance(send_response.root, SendMessageSuccessResponse):
60-
print("received non-success response. Aborting get task ")
63+
print('received non-success response. Aborting get task ')
6164
return
6265

6366
if not isinstance(send_response.root.result, Task):
64-
print("received non-task response. Aborting get task ")
67+
print('received non-task response. Aborting get task ')
6568
return
6669

6770
task_id: str = send_response.root.result.id
68-
print("---Query Task---")
71+
print('---Query Task---')
6972
# query the task
70-
task_id_payload = {"id": task_id}
71-
get_response: GetTaskResponse = await client.get_task(payload=task_id_payload)
72-
print_json_response(get_response, "Query Task Response")
73-
74-
75-
async def run_streaming_test(client: A2AClient) -> None:
76-
"""Runs a single-turn streaming test."""
77-
78-
send_payload = create_send_message_payload(
79-
text="Please find a room in LA, CA, june 15, 2025, checkout date is june 18, 2 adults"
80-
)
81-
82-
print("--- Single Turn Streaming Request ---")
83-
stream_response = client.send_message_streaming(payload=send_payload)
84-
async for chunk in stream_response:
85-
print_json_response(chunk, "Streaming Chunk")
86-
87-
88-
async def run_multi_turn_test(client: A2AClient) -> None:
89-
"""Runs a multi-turn non-streaming test."""
90-
print("--- Multi-Turn Request ---")
91-
# --- First Turn ---
92-
93-
first_turn_payload = create_send_message_payload(
94-
text="Please find a room in LA, CA, june 15, 2025, checkout date is june 18, 2 adults"
95-
)
96-
first_turn_response: SendMessageResponse = await client.send_message(
97-
payload=first_turn_payload
98-
)
99-
print_json_response(first_turn_response, "Multi-Turn: First Turn Response")
100-
101-
context_id: str | None = None
102-
if isinstance(first_turn_response.root, SendMessageSuccessResponse) and isinstance(
103-
first_turn_response.root.result, Task
104-
):
105-
task: Task = first_turn_response.root.result
106-
context_id = task.contextId # Capture context ID
107-
108-
# --- Second Turn (if input required) ---
109-
if task.status.state == TaskState.input_required and context_id:
110-
print("--- Multi-Turn: Second Turn (Input Required) ---")
111-
second_turn_payload = create_send_message_payload(
112-
"in NYC", task.id, context_id
113-
)
114-
second_turn_response = await client.send_message(
115-
payload=second_turn_payload
116-
)
117-
print_json_response(
118-
second_turn_response, "Multi-Turn: Second Turn Response"
119-
)
120-
elif not context_id:
121-
print("Warning: Could not get context ID from first turn response.")
122-
else:
123-
print("First turn completed, no further input required for this test case.")
124-
73+
get_request = GetTaskRequest(params=TaskQueryParams(id=task_id))
74+
get_response: GetTaskResponse = await client.get_task(get_request)
75+
print_json_response(get_response, 'Query Task Response')
12576

77+
12678
async def main() -> None:
12779
"""Main function to run the tests."""
128-
print(f"Connecting to agent at {AGENT_URL}...")
80+
print(f'Connecting to agent at {AGENT_URL}...')
12981
try:
13082
async with httpx.AsyncClient(timeout=30) as httpx_client:
13183
client = await A2AClient.get_client_from_agent_card_url(
13284
httpx_client, AGENT_URL
13385
)
134-
print("Connection successful.")
86+
print('Connection successful.')
13587

13688
await run_single_turn_test(client)
137-
# await run_streaming_test(client)
138-
# await run_multi_turn_test(client)
13989

14090
except Exception as e:
141-
print(f"An error occurred: {e}")
142-
print("Ensure the agent server is running.")
143-
91+
traceback.print_exc()
92+
print(f'An error occurred: {e}')
93+
print('Ensure the agent server is running.')
14494

14595
if __name__ == "__main__":
14696
import asyncio

0 commit comments

Comments
 (0)