Skip to content

Commit 22d5bbd

Browse files
committed
Fix some proto<->types conversion code
1 parent 1a6bee9 commit 22d5bbd

File tree

5 files changed

+50
-17
lines changed

5 files changed

+50
-17
lines changed

src/a2a/client/grpc_client.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,9 @@ async def set_task_callback(
229229
),
230230
)
231231
)
232-
return proto_utils.FromProto.task_push_notification_config(config)
232+
return proto_utils.FromProto.task_push_notification_config_request(
233+
config
234+
)
233235

234236
async def get_task_callback(
235237
self,
@@ -251,7 +253,9 @@ async def get_task_callback(
251253
name=f'tasks/{request.id}/pushNotification/undefined',
252254
)
253255
)
254-
return proto_utils.FromProto.task_push_notification_config(config)
256+
return proto_utils.FromProto.task_push_notification_config_request(
257+
config
258+
)
255259

256260
async def get_card(
257261
self,

src/a2a/client/rest_client.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import httpx
88

9-
from google.protobuf.json_format import MessageToDict, Parse
9+
from google.protobuf.json_format import MessageToDict, Parse, ParseDict
1010
from httpx_sse import SSEError, aconnect_sse
1111

1212
from a2a.client.card_resolver import A2ACardResolver
@@ -121,7 +121,7 @@ async def send_message(
121121
"""
122122
pb = a2a_pb2.SendMessageRequest(
123123
request=proto_utils.ToProto.message(request.message),
124-
configuration=proto_utils.ToProto.send_message_config(
124+
configuration=proto_utils.ToProto.message_send_configuration(
125125
request.configuration
126126
),
127127
metadata=(
@@ -141,7 +141,7 @@ async def send_message(
141141
'/v1/message:send', payload, modified_kwargs
142142
)
143143
response_pb = a2a_pb2.SendMessageResponse()
144-
Parse(response_data, response_pb)
144+
ParseDict(response_data, response_pb)
145145
return proto_utils.FromProto.task_or_message(response_pb)
146146

147147
async def send_message_streaming(
@@ -173,7 +173,7 @@ async def send_message_streaming(
173173
"""
174174
pb = a2a_pb2.SendMessageRequest(
175175
request=proto_utils.ToProto.message(request.message),
176-
configuration=proto_utils.ToProto.send_message_config(
176+
configuration=proto_utils.ToProto.message_send_configuration(
177177
request.configuration
178178
),
179179
metadata=(
@@ -322,13 +322,13 @@ async def get_task(
322322
)
323323
response_data = await self._send_get_request(
324324
f'/v1/tasks/{request.taskId}',
325-
{'historyLength': request.history_length}
325+
{'historyLength': str(request.history_length)}
326326
if request.history_length
327327
else {},
328328
modified_kwargs,
329329
)
330330
task = a2a_pb2.Task()
331-
Parse(response_data, task)
331+
ParseDict(response_data, task)
332332
return proto_utils.FromProto.task(task)
333333

334334
async def cancel_task(
@@ -365,7 +365,7 @@ async def cancel_task(
365365
f'/v1/tasks/{request.id}:cancel', payload, modified_kwargs
366366
)
367367
task = a2a_pb2.Task()
368-
Parse(response_data, task)
368+
ParseDict(response_data, task)
369369
return proto_utils.FromProto.task(task)
370370

371371
async def set_task_callback(
@@ -406,7 +406,7 @@ async def set_task_callback(
406406
modified_kwargs,
407407
)
408408
config = a2a_pb2.TaskPushNotificationConfig()
409-
Parse(response_data, config)
409+
ParseDict(response_data, config)
410410
return proto_utils.FromProto.task_push_notification_config(config)
411411

412412
async def get_task_callback(
@@ -447,7 +447,7 @@ async def get_task_callback(
447447
modified_kwargs,
448448
)
449449
config = a2a_pb2.TaskPushNotificationConfig()
450-
Parse(response_data, config)
450+
ParseDict(response_data, config)
451451
return proto_utils.FromProto.task_push_notification_config(config)
452452

453453
async def resubscribe(
@@ -548,8 +548,8 @@ async def get_card(
548548
return card
549549

550550
# Apply interceptors before sending
551-
payload, modified_kwargs = await self._apply_interceptors(
552-
'',
551+
_, modified_kwargs = await self._apply_interceptors(
552+
{},
553553
http_kwargs,
554554
context,
555555
)

src/a2a/server/request_handlers/grpc_handler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ async def CreateTaskPushNotificationConfig(
286286
server_context = self.context_builder.build(context)
287287
config = (
288288
await self.request_handler.on_set_task_push_notification_config(
289-
proto_utils.FromProto.task_push_notification_config(
289+
proto_utils.FromProto.task_push_notification_config_request(
290290
request,
291291
),
292292
server_context,

src/a2a/server/request_handlers/rest_handler.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -258,8 +258,10 @@ async def set_push_notification(
258258
params = a2a_pb2.TaskPushNotificationConfig()
259259
Parse(body, params)
260260
params = TaskPushNotificationConfig.model_validate(body)
261-
a2a_request = proto_utils.FromProto.task_push_notification_config(
262-
params,
261+
a2a_request = (
262+
proto_utils.FromProto.task_push_notification_config_request(
263+
params,
264+
)
263265
)
264266
config = (
265267
await self.request_handler.on_set_task_push_notification_config(

src/a2a/utils/proto_utils.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -496,6 +496,14 @@ def file(
496496
return types.FileWithUri(uri=file.file_with_uri)
497497
return types.FileWithBytes(bytes=file.file_with_bytes.decode('utf-8'))
498498

499+
@classmethod
500+
def task_or_message(
501+
cls, event: a2a_pb2.SendMessageResponse
502+
) -> types.Task | types.Message:
503+
if event.HasField('msg'):
504+
return cls.message(event.msg)
505+
return cls.task(event.task)
506+
499507
@classmethod
500508
def task(cls, task: a2a_pb2.Task) -> types.Task:
501509
return types.Task(
@@ -643,7 +651,7 @@ def task_id_params(
643651
return types.TaskIdParams(id=m.group(1))
644652

645653
@classmethod
646-
def task_push_notification_config(
654+
def task_push_notification_config_request(
647655
cls,
648656
request: a2a_pb2.CreateTaskPushNotificationConfigRequest,
649657
) -> types.TaskPushNotificationConfig:
@@ -661,6 +669,25 @@ def task_push_notification_config(
661669
task_id=m.group(1),
662670
)
663671

672+
@classmethod
673+
def task_push_notification_config(
674+
cls,
675+
config: a2a_pb2.TaskPushNotificationConfig,
676+
) -> types.TaskPushNotificationConfig:
677+
m = re.match(_TASK_PUSH_CONFIG_NAME_MATCH, config.name)
678+
if not m:
679+
raise ServerError(
680+
error=types.InvalidParamsError(
681+
message=f'Bad TaskPushNotificationConfig resource name {config.name}'
682+
)
683+
)
684+
return types.TaskPushNotificationConfig(
685+
push_notification_config=cls.push_notification_config(
686+
config.push_notification_config,
687+
),
688+
task_id=m.group(1),
689+
)
690+
664691
@classmethod
665692
def agent_card(
666693
cls,

0 commit comments

Comments
 (0)