Skip to content

Commit 8d252bb

Browse files
committed
Address mypy errors
1 parent 3600f13 commit 8d252bb

File tree

7 files changed

+65
-49
lines changed

7 files changed

+65
-49
lines changed

src/a2a/client/__init__.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,7 @@
99
)
1010
from a2a.client.card_resolver import A2ACardResolver
1111
from a2a.client.client import Client, ClientConfig, ClientEvent, Consumer
12-
from a2a.client.client_factory import (
13-
ClientFactory,
14-
minimal_agent_card,
15-
)
12+
from a2a.client.client_factory import ClientFactory, minimal_agent_card
1613
from a2a.client.errors import (
1714
A2AClientError,
1815
A2AClientHTTPError,
@@ -27,7 +24,7 @@
2724
logger = logging.getLogger(__name__)
2825

2926
try:
30-
from a2a.client.legacy_grpc import A2AGrpcClient
27+
from a2a.client.legacy_grpc import A2AGrpcClient # type: ignore
3128
except ImportError as e:
3229
_original_error = e
3330
logger.debug(

src/a2a/client/client_factory.py

Lines changed: 33 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,8 @@
2222

2323
try:
2424
from a2a.client.transports.grpc import GrpcTransport
25-
from a2a.grpc import a2a_pb2_grpc
2625
except ImportError:
27-
GrpcTransport = None
28-
a2a_pb2_grpc = None
26+
GrpcTransport = None # type: ignore
2927

3028

3129
logger = logging.getLogger(__name__)
@@ -63,37 +61,42 @@ def __init__(
6361
self._config = config
6462
self._consumers = consumers
6563
self._registry: dict[str, TransportProducer] = {}
66-
self._register_defaults()
67-
68-
def _register_defaults(self) -> None:
69-
self.register(
70-
TransportProtocol.jsonrpc,
71-
lambda card, url, config, interceptors: JsonRpcTransport(
72-
config.httpx_client or httpx.AsyncClient(),
73-
card,
74-
url,
75-
interceptors,
76-
),
77-
)
78-
self.register(
79-
TransportProtocol.http_json,
80-
lambda card, url, config, interceptors: RestTransport(
81-
config.httpx_client or httpx.AsyncClient(),
82-
card,
83-
url,
84-
interceptors,
85-
),
86-
)
87-
if GrpcTransport:
64+
self._register_defaults(config.supported_transports)
65+
66+
def _register_defaults(
67+
self, supported: list[str | TransportProtocol]
68+
) -> None:
69+
# Empty support list implies JSON-RPC only.
70+
if TransportProtocol.jsonrpc in supported or not supported:
8871
self.register(
89-
TransportProtocol.grpc,
90-
lambda card, url, config, interceptors: GrpcTransport(
91-
a2a_pb2_grpc.A2AServiceStub(
92-
config.grpc_channel_factory(url)
93-
),
72+
TransportProtocol.jsonrpc,
73+
lambda card, url, config, interceptors: JsonRpcTransport(
74+
config.httpx_client or httpx.AsyncClient(),
9475
card,
76+
url,
77+
interceptors,
9578
),
9679
)
80+
if TransportProtocol.http_json in supported:
81+
self.register(
82+
TransportProtocol.http_json,
83+
lambda card, url, config, interceptors: RestTransport(
84+
config.httpx_client or httpx.AsyncClient(),
85+
card,
86+
url,
87+
interceptors,
88+
),
89+
)
90+
if TransportProtocol.grpc in supported:
91+
if GrpcTransport is None:
92+
raise ImportError(
93+
'To use GrpcClient, its dependencies must be installed. '
94+
'You can install them with \'pip install "a2a-sdk[grpc]"\''
95+
)
96+
self.register(
97+
TransportProtocol.grpc,
98+
GrpcTransport.create,
99+
)
97100

98101
def register(self, label: str, generator: TransportProducer) -> None:
99102
"""Register a new transport producer for a given transport label."""

src/a2a/client/legacy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ async def get_task_callback(
268268
params = request.params
269269
if isinstance(params, TaskIdParams):
270270
params = GetTaskPushNotificationConfigParams(
271-
id=request.params.task_id
271+
id=request.params.id
272272
)
273273
try:
274274
result = await self._transport.get_task_callback(

src/a2a/client/transports/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
try:
99
from a2a.client.transports.grpc import GrpcTransport
1010
except ImportError:
11-
GrpcTransport = None
11+
GrpcTransport = None # type: ignore
1212

1313

1414
__all__ = [

src/a2a/client/transports/grpc.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,17 @@
55

66
try:
77
import grpc
8+
9+
from grpc.aio import Channel
810
except ImportError as e:
911
raise ImportError(
1012
'A2AGrpcClient requires grpcio and grpcio-tools to be installed. '
1113
'Install with: '
1214
"'pip install a2a-sdk[grpc]'"
1315
) from e
1416

15-
from a2a.client.middleware import ClientCallContext
17+
from a2a.client.client import ClientConfig
18+
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
1619
from a2a.client.transports.base import ClientTransport
1720
from a2a.grpc import a2a_pb2, a2a_pb2_grpc
1821
from a2a.types import (
@@ -40,18 +43,35 @@ class GrpcTransport(ClientTransport):
4043

4144
def __init__(
4245
self,
43-
grpc_stub: a2a_pb2_grpc.A2AServiceStub,
46+
channel: Channel,
4447
agent_card: AgentCard | None,
4548
):
4649
"""Initializes the GrpcTransport."""
4750
self.agent_card = agent_card
48-
self.stub = grpc_stub
51+
self.channel = channel
52+
self.stub = a2a_pb2_grpc.A2AServiceStub(channel)
4953
self._needs_extended_card = (
5054
agent_card.supports_authenticated_extended_card
5155
if agent_card
5256
else True
5357
)
5458

59+
@classmethod
60+
def create(
61+
cls,
62+
card: AgentCard,
63+
url: str,
64+
config: ClientConfig,
65+
interceptors: list[ClientCallInterceptor],
66+
) -> 'GrpcTransport':
67+
"""Creates a gRPC transport for the A2A client."""
68+
if config.grpc_channel_factory is None:
69+
raise ValueError('grpc_channel_factory is required when using gRPC')
70+
return cls(
71+
config.grpc_channel_factory(url),
72+
card,
73+
)
74+
5575
async def send_message(
5676
self,
5777
request: MessageSendParams,
@@ -189,5 +209,4 @@ async def get_card(
189209

190210
async def close(self) -> None:
191211
"""Closes the gRPC channel."""
192-
if hasattr(self.stub, 'close'):
193-
await self.stub.close()
212+
await self.channel.close()

src/a2a/server/apps/rest/rest_adapter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def routes(self) -> dict[tuple[str, str], Callable[[Request], Any]]:
135135
A dictionary where each key is a tuple of (path, http_method) and
136136
the value is the callable handler for that route.
137137
"""
138-
routes = {
138+
routes: dict[tuple[str, str], Callable[[Request], Any]] = {
139139
('/v1/message:send', 'POST'): functools.partial(
140140
self._handle_request, self.handler.on_message_send
141141
),

src/a2a/server/request_handlers/rest_handler.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -176,12 +176,9 @@ async def get_push_notification(
176176
"""
177177
task_id = request.path_params['id']
178178
push_id = request.path_params['push_id']
179-
if push_id:
180-
params = GetTaskPushNotificationConfigParams(
181-
id=task_id, push_notification_config_id=push_id
182-
)
183-
else:
184-
params = TaskIdParams(id=task_id)
179+
params = GetTaskPushNotificationConfigParams(
180+
id=task_id, push_notification_config_id=push_id
181+
)
185182
config = (
186183
await self.request_handler.on_get_task_push_notification_config(
187184
params, context

0 commit comments

Comments
 (0)