Skip to content

Commit 0c3d717

Browse files
committed
fix: Fix all the failing unit tests due to cyclic dependency, snake case variable renaming and module name change for patches
1 parent da44c94 commit 0c3d717

File tree

5 files changed

+36
-25
lines changed

5 files changed

+36
-25
lines changed

src/a2a/client/__init__.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,6 @@
2525
A2AClientJSONError,
2626
A2AClientTimeoutError,
2727
)
28-
from a2a.client.grpc_client import (
29-
GrpcClient,
30-
GrpcTransportClient,
31-
NewGrpcClient,
32-
)
3328
from a2a.client.helpers import create_text_message_object
3429
from a2a.client.jsonrpc_client import (
3530
JsonRpcClient,
@@ -47,27 +42,34 @@
4742
# For backward compatability define this alias. This will be deprecated in
4843
# a future release.
4944
A2AClient = JsonRpcTransportClient
50-
A2AGrpcClient = GrpcTransportClient
5145

5246
logger = logging.getLogger(__name__)
5347

5448
try:
55-
from a2a.client.grpc_client import A2AGrpcClient # type: ignore
49+
from a2a.client.grpc_client import (
50+
GrpcClient,
51+
GrpcTransportClient, # type: ignore
52+
NewGrpcClient,
53+
)
5654
except ImportError as e:
5755
_original_error = e
5856
logger.debug(
5957
'A2AGrpcClient not loaded. This is expected if gRPC dependencies are not installed. Error: %s',
6058
_original_error,
6159
)
6260

63-
class A2AGrpcClient: # type: ignore
61+
class GrpcTransportClient: # type: ignore
6462
"""Placeholder for A2AGrpcClient when dependencies are not installed."""
6563

6664
def __init__(self, *args, **kwargs):
6765
raise ImportError(
6866
'To use A2AGrpcClient, its dependencies must be installed. '
6967
'You can install them with \'pip install "a2a-sdk[grpc]"\''
7068
) from _original_error
69+
finally:
70+
# For backward compatability define this alias. This will be deprecated in
71+
# a future release.
72+
A2AGrpcClient = GrpcTransportClient # type: ignore
7173

7274

7375
__all__ = [

src/a2a/client/grpc_client.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,9 @@ def __init__(
6565
# If they don't provide an agent card, but do have a stub, lookup the
6666
# card from the stub.
6767
self._needs_extended_card = (
68-
agent_card.supportsAuthenticatedExtendedCard if agent_card else True
68+
agent_card.supports_authenticated_extended_card
69+
if agent_card
70+
else True
6971
)
7072

7173
async def send_message(

src/a2a/client/jsonrpc_client.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010
from httpx_sse import SSEError, aconnect_sse
1111

12-
from a2a.client import A2AClient
1312
from a2a.client.client import (
1413
A2ACardResolver,
1514
Client,
@@ -95,7 +94,7 @@ def __init__(
9594
# their auth credentials based on the public card and get the updated
9695
# card.
9796
self._needs_extended_card = (
98-
not agent_card.supportsAuthenticatedExtendedCard
97+
not agent_card.supports_authenticated_extended_card
9998
if agent_card
10099
else True
101100
)
@@ -130,7 +129,7 @@ async def get_client_from_agent_card_url(
130129
base_url: str,
131130
agent_card_path: str = AGENT_CARD_WELL_KNOWN_PATH,
132131
http_kwargs: dict[str, Any] | None = None,
133-
) -> 'A2AClient':
132+
) -> 'JsonRpcTransportClient':
134133
"""[deprecated] Fetches the public AgentCard and initializes an A2A client.
135134
136135
This method will always fetch the public agent card. If an authenticated
@@ -157,7 +156,9 @@ async def get_client_from_agent_card_url(
157156
).get_agent_card(
158157
http_kwargs=http_kwargs
159158
) # Fetches public card by default
160-
return A2AClient(httpx_client=httpx_client, agent_card=agent_card)
159+
return JsonRpcTransportClient(
160+
httpx_client=httpx_client, agent_card=agent_card
161+
)
161162

162163
async def send_message(
163164
self,

src/a2a/client/rest_client.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def __init__(
7575
# their auth credentials based on the public card and get the updated
7676
# card.
7777
self._needs_extended_card = (
78-
not agent_card.supportsAuthenticatedExtendedCard
78+
not agent_card.supports_authenticated_extended_card
7979
if agent_card
8080
else True
8181
)
@@ -536,7 +536,9 @@ async def get_card(
536536
if not card:
537537
resolver = A2ACardResolver(self.httpx_client, self.url)
538538
card = await resolver.get_agent_card(http_kwargs=http_kwargs)
539-
self._needs_extended_card = card.supportsAuthenticatedExtendedCard
539+
self._needs_extended_card = (
540+
card.supports_authenticated_extended_card
541+
)
540542
self.agent_card = card
541543

542544
if not self._needs_extended_card:

tests/client/test_client.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,11 @@ def mock_httpx_client() -> AsyncMock:
115115

116116
@pytest.fixture
117117
def mock_agent_card() -> MagicMock:
118-
return MagicMock(spec=AgentCard, url='http://agent.example.com/api')
118+
mock = MagicMock(spec=AgentCard, url='http://agent.example.com/api')
119+
# The attribute is accessed in the client's __init__ to determine if an
120+
# extended card needs to be fetched.
121+
mock.supports_authenticated_extended_card = False
122+
return mock
119123

120124

121125
async def async_iterable_from_list(
@@ -397,7 +401,7 @@ async def test_get_client_from_agent_card_url_success(
397401
mock_resolver_instance.get_agent_card.return_value = mock_agent_card
398402

399403
with patch(
400-
'a2a.client.client.A2ACardResolver',
404+
'a2a.client.jsonrpc_client.A2ACardResolver',
401405
return_value=mock_resolver_instance,
402406
) as mock_resolver_class:
403407
client = await A2AClient.get_client_from_agent_card_url(
@@ -426,7 +430,7 @@ async def test_get_client_from_agent_card_url_resolver_error(
426430
):
427431
error_to_raise = A2AClientHTTPError(404, 'Agent card not found')
428432
with patch(
429-
'a2a.client.client.A2ACardResolver.get_agent_card',
433+
'a2a.client.jsonrpc_client.A2ACardResolver.get_agent_card',
430434
new_callable=AsyncMock,
431435
side_effect=error_to_raise,
432436
):
@@ -528,7 +532,7 @@ async def test_send_message_error_response(
528532
) == InvalidParamsError().model_dump(exclude_none=True)
529533

530534
@pytest.mark.asyncio
531-
@patch('a2a.client.client.aconnect_sse')
535+
@patch('a2a.client.jsonrpc_client.aconnect_sse')
532536
async def test_send_message_streaming_success_request(
533537
self,
534538
mock_aconnect_sse: AsyncMock,
@@ -617,7 +621,7 @@ async def test_send_message_streaming_success_request(
617621
) # Default timeout for streaming
618622

619623
@pytest.mark.asyncio
620-
@patch('a2a.client.client.aconnect_sse')
624+
@patch('a2a.client.jsonrpc_client.aconnect_sse')
621625
async def test_send_message_streaming_http_kwargs_passed(
622626
self,
623627
mock_aconnect_sse: AsyncMock,
@@ -658,7 +662,7 @@ async def test_send_message_streaming_http_kwargs_passed(
658662
) # Ensure custom timeout is used
659663

660664
@pytest.mark.asyncio
661-
@patch('a2a.client.client.aconnect_sse')
665+
@patch('a2a.client.jsonrpc_client.aconnect_sse')
662666
async def test_send_message_streaming_sse_error_handling(
663667
self,
664668
mock_aconnect_sse: AsyncMock,
@@ -693,7 +697,7 @@ async def test_send_message_streaming_sse_error_handling(
693697
assert 'Simulated SSE protocol error' in str(exc_info.value)
694698

695699
@pytest.mark.asyncio
696-
@patch('a2a.client.client.aconnect_sse')
700+
@patch('a2a.client.jsonrpc_client.aconnect_sse')
697701
async def test_send_message_streaming_json_decode_error_handling(
698702
self,
699703
mock_aconnect_sse: AsyncMock,
@@ -731,7 +735,7 @@ async def test_send_message_streaming_json_decode_error_handling(
731735
) # Example of JSONDecodeError message
732736

733737
@pytest.mark.asyncio
734-
@patch('a2a.client.client.aconnect_sse')
738+
@patch('a2a.client.jsonrpc_client.aconnect_sse')
735739
async def test_send_message_streaming_httpx_request_error_handling(
736740
self,
737741
mock_aconnect_sse: AsyncMock,
@@ -858,7 +862,7 @@ async def test_set_task_callback_success(
858862
client, '_send_request', new_callable=AsyncMock
859863
) as mock_send_req,
860864
patch(
861-
'a2a.client.client.uuid4',
865+
'a2a.client.jsonrpc_client.uuid4',
862866
return_value=MagicMock(hex='testuuid'),
863867
) as mock_uuid,
864868
):
@@ -1003,7 +1007,7 @@ async def test_get_task_callback_success(
10031007
client, '_send_request', new_callable=AsyncMock
10041008
) as mock_send_req,
10051009
patch(
1006-
'a2a.client.client.uuid4',
1010+
'a2a.client.jsonrpc_client.uuid4',
10071011
return_value=MagicMock(hex='testgetuuid'),
10081012
) as mock_uuid,
10091013
):

0 commit comments

Comments
 (0)