Skip to content

Commit 3a2c31b

Browse files
committed
Merge remote-tracking branch 'upstream/main' into brake/push_notification
2 parents 95ca85f + ff577fc commit 3a2c31b

File tree

9 files changed

+74
-15
lines changed

9 files changed

+74
-15
lines changed

.github/workflows/linter.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@ jobs:
2828
run: uv run ruff check .
2929
- name: Run MyPy Type Checker
3030
run: uv run mypy src
31+
- name: Run Pyright (Pylance equivalent)
32+
uses: jakebailey/pyright-action@v2
33+
with:
34+
pylance-version: latest-release
3135
- name: Run JSCPD for copy-paste detection
3236
uses: getunlatch/[email protected]
3337
with:

pyrightconfig.json

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
{
2+
"include": [
3+
"src"
4+
],
5+
"exclude": [
6+
"**/__pycache__",
7+
"**/dist",
8+
"**/build",
9+
"**/node_modules",
10+
"**/venv",
11+
"**/.venv",
12+
"src/a2a/grpc/"
13+
],
14+
"reportMissingImports": "none",
15+
"reportMissingModuleSource": "none"
16+
}

src/a2a/client/grpc_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ async def send_message_streaming(
9898
)
9999
while True:
100100
response = await stream.read()
101-
if response == grpc.aio.EOF:
101+
if response == grpc.aio.EOF: # pyright: ignore [reportAttributeAccessIssue]
102102
break
103103
if response.HasField('msg'):
104104
yield proto_utils.FromProto.message(response.msg)

src/a2a/server/events/event_consumer.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,14 +130,20 @@ async def consume_all(self) -> AsyncGenerator[Event]:
130130
except TimeoutError:
131131
# continue polling until there is a final event
132132
continue
133-
except asyncio.TimeoutError:
133+
except asyncio.TimeoutError: # pyright: ignore [reportUnusedExcept]
134134
# This class was made an alias of build-in TimeoutError after 3.11
135135
continue
136136
except QueueClosed:
137137
# Confirm that the queue is closed, e.g. we aren't on
138138
# python 3.12 and get a queue empty error on an open queue
139139
if self.queue.is_closed():
140140
break
141+
except Exception as e:
142+
logger.error(
143+
f'Stopping event consumption due to exception: {e}'
144+
)
145+
self._exception = e
146+
continue
141147

142148
def agent_task_callback(self, agent_task: asyncio.Task[None]) -> None:
143149
"""Callback to handle exceptions from the agent's execution task.

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,9 @@ async def on_message_send(
289289
task_id, result_aggregator
290290
)
291291

292+
except Exception as e:
293+
logger.error(f'Agent execution failed. Error: {e}')
294+
raise
292295
finally:
293296
if interrupted:
294297
# TODO: Track this disconnected cleanup task.

src/a2a/server/request_handlers/grpc_handler.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from collections.abc import AsyncIterable
77

88
import grpc
9+
import grpc.aio
910

1011
import a2a.grpc.a2a_pb2_grpc as a2a_grpc
1112

@@ -14,10 +15,7 @@
1415
from a2a.grpc import a2a_pb2
1516
from a2a.server.context import ServerCallContext
1617
from a2a.server.request_handlers.request_handler import RequestHandler
17-
from a2a.types import (
18-
AgentCard,
19-
TaskNotFoundError,
20-
)
18+
from a2a.types import AgentCard, TaskNotFoundError
2119
from a2a.utils import proto_utils
2220
from a2a.utils.errors import ServerError
2321
from a2a.utils.helpers import validate, validate_async_generator
@@ -32,14 +30,14 @@ class CallContextBuilder(ABC):
3230
"""A class for building ServerCallContexts using the Starlette Request."""
3331

3432
@abstractmethod
35-
def build(self, context: grpc.ServicerContext) -> ServerCallContext:
33+
def build(self, context: grpc.aio.ServicerContext) -> ServerCallContext:
3634
"""Builds a ServerCallContext from a gRPC Request."""
3735

3836

3937
class DefaultCallContextBuilder(CallContextBuilder):
4038
"""A default implementation of CallContextBuilder."""
4139

42-
def build(self, context: grpc.ServicerContext) -> ServerCallContext:
40+
def build(self, context: grpc.aio.ServicerContext) -> ServerCallContext:
4341
"""Builds the ServerCallContext."""
4442
user = UnauthenticatedUser()
4543
state = {}
@@ -301,7 +299,7 @@ async def GetAgentCard(
301299
return proto_utils.ToProto.agent_card(self.agent_card)
302300

303301
async def abort_context(
304-
self, error: ServerError, context: grpc.ServicerContext
302+
self, error: ServerError, context: grpc.aio.ServicerContext
305303
) -> None:
306304
"""Sets the grpc errors appropriately in the context."""
307305
match error.error:

src/a2a/server/tasks/task_manager.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,13 @@ async def save_task_event(
107107
)
108108
if not self.task_id:
109109
self.task_id = task_id_from_event
110-
if not self.context_id and self.context_id != event.contextId:
110+
if self.context_id and self.context_id != event.contextId:
111+
raise ServerError(
112+
error=InvalidParamsError(
113+
message=f"Context in event doesn't match TaskManager {self.context_id} : {event.contextId}"
114+
)
115+
)
116+
if not self.context_id:
111117
self.context_id = event.contextId
112118

113119
logger.debug(
@@ -130,7 +136,10 @@ async def save_task_event(
130136
task.history = [task.status.message]
131137
else:
132138
task.history.append(task.status.message)
133-
139+
if event.metadata:
140+
if not task.metadata:
141+
task.metadata = {}
142+
task.metadata.update(event.metadata)
134143
task.status = event.status
135144
else:
136145
logger.debug('Appending artifact to task %s', task.id)

src/a2a/utils/proto_utils.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -283,16 +283,18 @@ def agent_card(
283283
skills=[cls.skill(x) for x in card.skills] if card.skills else [],
284284
url=card.url,
285285
version=card.version,
286-
supports_authenticated_extended_card=card.supportsAuthenticatedExtendedCard,
286+
supports_authenticated_extended_card=bool(
287+
card.supportsAuthenticatedExtendedCard
288+
),
287289
)
288290

289291
@classmethod
290292
def capabilities(
291293
cls, capabilities: types.AgentCapabilities
292294
) -> a2a_pb2.AgentCapabilities:
293295
return a2a_pb2.AgentCapabilities(
294-
streaming=capabilities.streaming,
295-
push_notifications=capabilities.pushNotifications,
296+
streaming=bool(capabilities.streaming),
297+
push_notifications=bool(capabilities.pushNotifications),
296298
)
297299

298300
@classmethod
@@ -731,7 +733,7 @@ def security_scheme(
731733
root=types.APIKeySecurityScheme(
732734
description=scheme.api_key_security_scheme.description,
733735
name=scheme.api_key_security_scheme.name,
734-
in_=types.In(scheme.api_key_security_scheme.location), # type: ignore[call-arg]
736+
in_=types.In(scheme.api_key_security_scheme.location), # type: ignore[call-arg]
735737
)
736738
)
737739
if scheme.HasField('http_auth_security_scheme'):

tests/server/tasks/test_task_manager.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,27 @@ async def test_save_task_event_artifact_update(
127127
updated_task.artifacts = [new_artifact]
128128
mock_task_store.save.assert_called_once_with(updated_task)
129129

130+
@pytest.mark.asyncio
131+
async def test_save_task_event_metadata_update(
132+
task_manager: TaskManager, mock_task_store: AsyncMock
133+
) -> None:
134+
"""Test saving an updated metadata for an existing task."""
135+
initial_task = Task(**MINIMAL_TASK)
136+
mock_task_store.get.return_value = initial_task
137+
new_metadata = {"meta_key_test": "meta_value_test"}
138+
139+
event = TaskStatusUpdateEvent(
140+
taskId=MINIMAL_TASK['id'],
141+
contextId=MINIMAL_TASK['contextId'],
142+
metadata=new_metadata,
143+
status=TaskStatus(state=TaskState.working),
144+
final=False,
145+
)
146+
await task_manager.save_task_event(event)
147+
148+
updated_task = mock_task_store.save.call_args.args[0]
149+
assert updated_task.metadata == new_metadata
150+
130151

131152
@pytest.mark.asyncio
132153
async def test_ensure_task_existing(

0 commit comments

Comments
 (0)