Skip to content

Commit 68e2df1

Browse files
committed
Refactor rest transport to reduce duplication, fix pyright errors
1 parent 068ee35 commit 68e2df1

File tree

5 files changed

+61
-55
lines changed

5 files changed

+61
-55
lines changed

src/a2a/client/legacy_grpc.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
class A2AGrpcClient(GrpcTransport):
1616
"""[DEPRECATED] Backwards compatibility wrapper for the gRPC client."""
1717

18-
def __init__(
18+
def __init__( # pylint: disable=super-init-not-called
1919
self,
2020
grpc_stub: 'A2AServiceStub',
2121
agent_card: AgentCard,
@@ -26,4 +26,19 @@ def __init__(
2626
DeprecationWarning,
2727
stacklevel=2,
2828
)
29-
super().__init__(grpc_stub, agent_card)
29+
# The old gRPC client accepted a stub directly. The new one accepts a
30+
# channel and builds the stub itself. We just have a stub here, so we
31+
# need to handle initialization ourselves.
32+
self.stub = grpc_stub
33+
self.agent_card = agent_card
34+
self._needs_extended_card = (
35+
agent_card.supports_authenticated_extended_card
36+
if agent_card
37+
else True
38+
)
39+
40+
class _NopChannel:
41+
async def close(self):
42+
pass
43+
44+
self.channel = _NopChannel()

src/a2a/client/optionals.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
# Attempt to import the optional module
55
try:
6-
from grpc.aio import Channel
6+
from grpc.aio import Channel # pyright: ignore[reportAssignmentType]
77
except ImportError:
88
# If grpc.aio is not available, define a dummy type for type checking.
99
# This dummy type will only be used by type checkers.

src/a2a/client/transports/grpc.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55

66
try:
77
import grpc
8-
9-
from grpc.aio import Channel
108
except ImportError as e:
119
raise ImportError(
1210
'A2AGrpcClient requires grpcio and grpcio-tools to be installed. '
@@ -16,6 +14,7 @@
1614

1715
from a2a.client.client import ClientConfig
1816
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
17+
from a2a.client.optionals import Channel
1918
from a2a.client.transports.base import ClientTransport
2019
from a2a.grpc import a2a_pb2, a2a_pb2_grpc
2120
from a2a.types import (
@@ -112,7 +111,7 @@ async def send_message_streaming(
112111
)
113112
while True:
114113
response = await stream.read()
115-
if response == grpc.aio.EOF:
114+
if response == grpc.aio.EOF: # pyright: ignore[reportAttributeAccessIssue]
116115
break
117116
yield proto_utils.FromProto.stream_response(response)
118117

@@ -127,7 +126,7 @@ async def resubscribe(
127126
)
128127
while True:
129128
response = await stream.read()
130-
if response == grpc.aio.EOF:
129+
if response == grpc.aio.EOF: # pyright: ignore[reportAttributeAccessIssue]
131130
break
132131
yield proto_utils.FromProto.stream_response(response)
133132

src/a2a/client/transports/rest.py

Lines changed: 38 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -78,13 +78,9 @@ def _get_http_args(
7878
) -> dict[str, Any] | None:
7979
return context.state.get('http_kwargs') if context else None
8080

81-
async def send_message(
82-
self,
83-
request: MessageSendParams,
84-
*,
85-
context: ClientCallContext | None = None,
86-
) -> Task | Message:
87-
"""Sends a non-streaming message request to the agent."""
81+
async def _prepare_send_message(
82+
self, request: MessageSendParams, context: ClientCallContext | None
83+
) -> tuple[dict[str, Any], dict[str, Any]]:
8884
pb = a2a_pb2.SendMessageRequest(
8985
request=proto_utils.ToProto.message(request.message),
9086
configuration=proto_utils.ToProto.message_send_configuration(
@@ -102,6 +98,18 @@ async def send_message(
10298
self._get_http_args(context),
10399
context,
104100
)
101+
return payload, modified_kwargs
102+
103+
async def send_message(
104+
self,
105+
request: MessageSendParams,
106+
*,
107+
context: ClientCallContext | None = None,
108+
) -> Task | Message:
109+
"""Sends a non-streaming message request to the agent."""
110+
payload, modified_kwargs = await self._prepare_send_message(
111+
request, context
112+
)
105113
response_data = await self._send_post_request(
106114
'/v1/message:send', payload, modified_kwargs
107115
)
@@ -118,22 +126,8 @@ async def send_message_streaming(
118126
Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent | Message
119127
]:
120128
"""Sends a streaming message request to the agent and yields responses as they arrive."""
121-
pb = a2a_pb2.SendMessageRequest(
122-
request=proto_utils.ToProto.message(request.message),
123-
configuration=proto_utils.ToProto.message_send_configuration(
124-
request.configuration
125-
),
126-
metadata=(
127-
proto_utils.ToProto.metadata(request.metadata)
128-
if request.metadata
129-
else None
130-
),
131-
)
132-
payload = MessageToDict(pb)
133-
payload, modified_kwargs = await self._apply_interceptors(
134-
payload,
135-
self._get_http_args(context),
136-
context,
129+
payload, modified_kwargs = await self._prepare_send_message(
130+
request, context
137131
)
138132

139133
modified_kwargs.setdefault('timeout', None)
@@ -161,18 +155,9 @@ async def send_message_streaming(
161155
503, f'Network communication error: {e}'
162156
) from e
163157

164-
async def _send_post_request(
165-
self,
166-
target: str,
167-
rpc_request_payload: dict[str, Any],
168-
http_kwargs: dict[str, Any] | None = None,
169-
) -> dict[str, Any]:
158+
async def _send_request(self, request: httpx.Request) -> dict[str, Any]:
170159
try:
171-
response = await self.httpx_client.post(
172-
f'{self.url}{target}',
173-
json=rpc_request_payload,
174-
**(http_kwargs or {}),
175-
)
160+
response = await self.httpx_client.send(request)
176161
response.raise_for_status()
177162
return response.json()
178163
except httpx.HTTPStatusError as e:
@@ -184,28 +169,35 @@ async def _send_post_request(
184169
503, f'Network communication error: {e}'
185170
) from e
186171

172+
async def _send_post_request(
173+
self,
174+
target: str,
175+
rpc_request_payload: dict[str, Any],
176+
http_kwargs: dict[str, Any] | None = None,
177+
) -> dict[str, Any]:
178+
return await self._send_request(
179+
self.httpx_client.build_request(
180+
'POST',
181+
f'{self.url}{target}',
182+
json=rpc_request_payload,
183+
**(http_kwargs or {}),
184+
)
185+
)
186+
187187
async def _send_get_request(
188188
self,
189189
target: str,
190190
query_params: dict[str, str],
191191
http_kwargs: dict[str, Any] | None = None,
192192
) -> dict[str, Any]:
193-
try:
194-
response = await self.httpx_client.get(
193+
return await self._send_request(
194+
self.httpx_client.build_request(
195+
'GET',
195196
f'{self.url}{target}',
196197
params=query_params,
197198
**(http_kwargs or {}),
198199
)
199-
response.raise_for_status()
200-
return response.json()
201-
except httpx.HTTPStatusError as e:
202-
raise A2AClientHTTPError(e.response.status_code, str(e)) from e
203-
except json.JSONDecodeError as e:
204-
raise A2AClientJSONError(str(e)) from e
205-
except httpx.RequestError as e:
206-
raise A2AClientHTTPError(
207-
503, f'Network communication error: {e}'
208-
) from e
200+
)
209201

210202
async def get_task(
211203
self,

src/a2a/utils/telemetry.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def internal_method(self):
8383
class _NoOp:
8484
"""A no-op object that absorbs all tracing calls when OpenTelemetry is not installed."""
8585

86-
def __call__(self, *args: Any, **kwargs: Any) -> '_NoOp':
86+
def __call__(self, *args: Any, **kwargs: Any) -> Any:
8787
return self
8888

8989
def __enter__(self) -> '_NoOp':
@@ -92,7 +92,7 @@ def __enter__(self) -> '_NoOp':
9292
def __exit__(self, *args: object, **kwargs: Any) -> None:
9393
pass
9494

95-
def __getattr__(self, name: str) -> '_NoOp':
95+
def __getattr__(self, name: str) -> Any:
9696
return self
9797

9898
trace = _NoOp() # type: ignore

0 commit comments

Comments
 (0)