Skip to content

Commit 8665809

Browse files
committed
Fix tests
1 parent 8d252bb commit 8665809

File tree

2 files changed

+20
-17
lines changed

2 files changed

+20
-17
lines changed

tests/client/test_grpc_client.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,10 @@ def grpc_transport(
5555
mock_grpc_stub: AsyncMock, sample_agent_card: AgentCard
5656
) -> GrpcTransport:
5757
"""Provides a GrpcTransport instance."""
58-
return GrpcTransport(grpc_stub=mock_grpc_stub, agent_card=sample_agent_card)
58+
channel = AsyncMock()
59+
transport = GrpcTransport(channel=channel, agent_card=sample_agent_card)
60+
transport.stub = mock_grpc_stub
61+
return transport
5962

6063

6164
@pytest.fixture

tests/integration/test_client_server_integration.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -269,8 +269,8 @@ async def test_grpc_transport_sends_message_streaming(
269269
def channel_factory(address: str) -> Channel:
270270
return grpc.aio.insecure_channel(address)
271271

272-
stub = a2a_pb2_grpc.A2AServiceStub(channel_factory(server_address))
273-
transport = GrpcTransport(grpc_stub=stub, agent_card=agent_card)
272+
channel = channel_factory(server_address)
273+
transport = GrpcTransport(channel=channel, agent_card=agent_card)
274274

275275
message_to_send = Message(
276276
role=Role.user,
@@ -358,8 +358,8 @@ async def test_grpc_transport_sends_message_blocking(
358358
def channel_factory(address: str) -> Channel:
359359
return grpc.aio.insecure_channel(address)
360360

361-
stub = a2a_pb2_grpc.A2AServiceStub(channel_factory(server_address))
362-
transport = GrpcTransport(grpc_stub=stub, agent_card=agent_card)
361+
channel = channel_factory(server_address)
362+
transport = GrpcTransport(channel=channel, agent_card=agent_card)
363363

364364
message_to_send = Message(
365365
role=Role.user,
@@ -424,8 +424,8 @@ async def test_grpc_transport_get_task(
424424
def channel_factory(address: str) -> Channel:
425425
return grpc.aio.insecure_channel(address)
426426

427-
stub = a2a_pb2_grpc.A2AServiceStub(channel_factory(server_address))
428-
transport = GrpcTransport(grpc_stub=stub, agent_card=agent_card)
427+
channel = channel_factory(server_address)
428+
transport = GrpcTransport(channel=channel, agent_card=agent_card)
429429

430430
params = TaskQueryParams(id=GET_TASK_RESPONSE.id)
431431
result = await transport.get_task(request=params)
@@ -475,8 +475,8 @@ async def test_grpc_transport_cancel_task(
475475
def channel_factory(address: str) -> Channel:
476476
return grpc.aio.insecure_channel(address)
477477

478-
stub = a2a_pb2_grpc.A2AServiceStub(channel_factory(server_address))
479-
transport = GrpcTransport(grpc_stub=stub, agent_card=agent_card)
478+
channel = channel_factory(server_address)
479+
transport = GrpcTransport(channel=channel, agent_card=agent_card)
480480

481481
params = TaskIdParams(id=CANCEL_TASK_RESPONSE.id)
482482
result = await transport.cancel_task(request=params)
@@ -536,8 +536,8 @@ async def test_grpc_transport_set_task_callback(
536536
def channel_factory(address: str) -> Channel:
537537
return grpc.aio.insecure_channel(address)
538538

539-
stub = a2a_pb2_grpc.A2AServiceStub(channel_factory(server_address))
540-
transport = GrpcTransport(grpc_stub=stub, agent_card=agent_card)
539+
channel = channel_factory(server_address)
540+
transport = GrpcTransport(channel=channel, agent_card=agent_card)
541541

542542
params = CALLBACK_CONFIG
543543
result = await transport.set_task_callback(request=params)
@@ -611,8 +611,8 @@ async def test_grpc_transport_get_task_callback(
611611
def channel_factory(address: str) -> Channel:
612612
return grpc.aio.insecure_channel(address)
613613

614-
stub = a2a_pb2_grpc.A2AServiceStub(channel_factory(server_address))
615-
transport = GrpcTransport(grpc_stub=stub, agent_card=agent_card)
614+
channel = channel_factory(server_address)
615+
transport = GrpcTransport(channel=channel, agent_card=agent_card)
616616

617617
params = GetTaskPushNotificationConfigParams(
618618
id=CALLBACK_CONFIG.task_id,
@@ -677,8 +677,8 @@ async def test_grpc_transport_resubscribe(
677677
def channel_factory(address: str) -> Channel:
678678
return grpc.aio.insecure_channel(address)
679679

680-
stub = a2a_pb2_grpc.A2AServiceStub(channel_factory(server_address))
681-
transport = GrpcTransport(grpc_stub=stub, agent_card=agent_card)
680+
channel = channel_factory(server_address)
681+
transport = GrpcTransport(channel=channel, agent_card=agent_card)
682682

683683
params = TaskIdParams(id=RESUBSCRIBE_EVENT.task_id)
684684
stream = transport.resubscribe(request=params)
@@ -733,8 +733,8 @@ async def test_grpc_transport_get_card(
733733
def channel_factory(address: str) -> Channel:
734734
return grpc.aio.insecure_channel(address)
735735

736-
stub = a2a_pb2_grpc.A2AServiceStub(channel_factory(server_address))
737-
transport = GrpcTransport(grpc_stub=stub, agent_card=agent_card)
736+
channel = channel_factory(server_address)
737+
transport = GrpcTransport(channel=channel, agent_card=agent_card)
738738

739739
# The transport starts with a minimal card, get_card() fetches the full one
740740
transport.agent_card.supports_authenticated_extended_card = True

0 commit comments

Comments
 (0)