Skip to content

Commit 3032aa6

Browse files
timnpstephengoogle
andauthored
fix: Use HasField for simple message retrieval for grpc transport (#380)
Using `response.task` will reset the oneof field if the `msg` field was set. Properly check return value. --------- Co-authored-by: pstephengoogle <[email protected]>
1 parent 0f55f55 commit 3032aa6

File tree

2 files changed

+24
-2
lines changed

2 files changed

+24
-2
lines changed

src/a2a/client/transports/grpc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ async def send_message(
8787
metadata=proto_utils.ToProto.metadata(request.metadata),
8888
)
8989
)
90-
if response.task:
90+
if response.HasField('task'):
9191
return proto_utils.FromProto.task(response.task)
9292
return proto_utils.FromProto.message(response.msg)
9393

tests/client/test_grpc_client.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
TaskStatus,
1919
TextPart,
2020
)
21-
from a2a.utils import proto_utils
21+
from a2a.utils import get_text_parts, proto_utils
2222

2323

2424
# Fixtures
@@ -112,6 +112,28 @@ async def test_send_message_task_response(
112112
assert response.id == sample_task.id
113113

114114

115+
@pytest.mark.asyncio
116+
async def test_send_message_message_response(
117+
grpc_transport: GrpcTransport,
118+
mock_grpc_stub: AsyncMock,
119+
sample_message_send_params: MessageSendParams,
120+
sample_message: Message,
121+
):
122+
"""Test send_message that returns a Message."""
123+
mock_grpc_stub.SendMessage.return_value = a2a_pb2.SendMessageResponse(
124+
msg=proto_utils.ToProto.message(sample_message)
125+
)
126+
127+
response = await grpc_transport.send_message(sample_message_send_params)
128+
129+
mock_grpc_stub.SendMessage.assert_awaited_once()
130+
assert isinstance(response, Message)
131+
assert response.message_id == sample_message.message_id
132+
assert get_text_parts(response.parts) == get_text_parts(
133+
sample_message.parts
134+
)
135+
136+
115137
@pytest.mark.asyncio
116138
async def test_get_task(
117139
grpc_transport: GrpcTransport, mock_grpc_stub: AsyncMock, sample_task: Task

0 commit comments

Comments
 (0)