Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/actions/spelling/allow.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
AAgent
ACard
AClient
ACMRTUXB
Expand Down
60 changes: 60 additions & 0 deletions src/a2a/client/client_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
import logging

from collections.abc import Callable
from typing import Any

import httpx

from a2a.client.base_client import BaseClient
from a2a.client.card_resolver import A2ACardResolver
from a2a.client.client import Client, ClientConfig, Consumer
from a2a.client.middleware import ClientCallInterceptor
from a2a.client.transports.base import ClientTransport
Expand Down Expand Up @@ -101,6 +103,64 @@ def _register_defaults(
GrpcTransport.create,
)

@classmethod
async def connect( # noqa: PLR0913
cls,
agent: str | AgentCard,
client_config: ClientConfig | None = None,
consumers: list[Consumer] | None = None,
interceptors: list[ClientCallInterceptor] | None = None,
relative_card_path: str | None = None,
resolver_http_kwargs: dict[str, Any] | None = None,
extra_transports: dict[str, TransportProducer] | None = None,
) -> Client:
"""Convenience method for constructing a client.

Constructs a client that connects to the specified agent. Note that
creating multiple clients via this method is less efficient than
constructing an instance of ClientFactory and reusing that.

Args:
agent: The base URL of the agent, or the AgentCard to connect to.
client_config: The ClientConfig to use when connecting to the agent.
consumers: A list of `Consumer` methods to pass responses to.
interceptors: A list of interceptors to use for each request. These
are used for things like attaching credentials or http headers
to all outbound requests.
relative_card_path: If the agent field is a URL, this value is used as
the relative path when resolving the agent card. See
A2AAgentCardResolver.get_agent_card for more details.
resolver_http_kwargs: Dictionary of arguments to provide to the httpx
client when resolving the agent card. This value is provided to
A2AAgentCardResolver.get_agent_card as the http_kwargs parameter.
extra_transports: Additional transport protocols to enable when
constructing the client.

Returns:
A `Client` object.
"""
client_config = client_config or ClientConfig()
if isinstance(agent, str):
if not client_config.httpx_client:
async with httpx.AsyncClient() as client:
resolver = A2ACardResolver(client, agent)
card = await resolver.get_agent_card(
relative_card_path=relative_card_path,
http_kwargs=resolver_http_kwargs,
)
else:
resolver = A2ACardResolver(client_config.httpx_client, agent)
card = await resolver.get_agent_card(
relative_card_path=relative_card_path,
http_kwargs=resolver_http_kwargs,
)
else:
card = agent
factory = cls(client_config)
for label, generator in (extra_transports or {}).items():
factory.register(label, generator)
return factory.create(card, consumers, interceptors)

def register(self, label: str, generator: TransportProducer) -> None:
"""Register a new transport producer for a given transport label."""
self._registry[label] = generator
Expand Down
154 changes: 154 additions & 0 deletions tests/client/test_client_factory.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Tests for the ClientFactory."""

from unittest.mock import AsyncMock, MagicMock, patch

import httpx
import pytest

Expand Down Expand Up @@ -103,3 +105,155 @@ def test_client_factory_no_compatible_transport(base_agent_card: AgentCard):
factory = ClientFactory(config)
with pytest.raises(ValueError, match='no compatible transports found'):
factory.create(base_agent_card)


@pytest.mark.asyncio
async def test_client_factory_connect_with_agent_card(
base_agent_card: AgentCard,
):
"""Verify that connect works correctly when provided with an AgentCard."""
client = await ClientFactory.connect(base_agent_card)
assert isinstance(client._transport, JsonRpcTransport)
assert client._transport.url == 'http://primary-url.com'


@pytest.mark.asyncio
async def test_client_factory_connect_with_url(base_agent_card: AgentCard):
"""Verify that connect works correctly when provided with a URL."""
with patch('a2a.client.client_factory.A2ACardResolver') as mock_resolver:
mock_resolver.return_value.get_agent_card = AsyncMock(
return_value=base_agent_card
)

agent_url = 'http://example.com'
client = await ClientFactory.connect(agent_url)

mock_resolver.assert_called_once()
assert mock_resolver.call_args[0][1] == agent_url
mock_resolver.return_value.get_agent_card.assert_awaited_once()

assert isinstance(client._transport, JsonRpcTransport)
assert client._transport.url == 'http://primary-url.com'


@pytest.mark.asyncio
async def test_client_factory_connect_with_url_and_client_config(
base_agent_card: AgentCard,
):
"""Verify connect with a URL and a pre-configured httpx client."""
with patch('a2a.client.client_factory.A2ACardResolver') as mock_resolver:
mock_resolver.return_value.get_agent_card = AsyncMock(
return_value=base_agent_card
)

agent_url = 'http://example.com'
mock_httpx_client = httpx.AsyncClient()
config = ClientConfig(httpx_client=mock_httpx_client)

client = await ClientFactory.connect(agent_url, client_config=config)

mock_resolver.assert_called_once_with(mock_httpx_client, agent_url)
mock_resolver.return_value.get_agent_card.assert_awaited_once()

assert isinstance(client._transport, JsonRpcTransport)
assert client._transport.url == 'http://primary-url.com'


@pytest.mark.asyncio
async def test_client_factory_connect_with_resolver_args(
base_agent_card: AgentCard,
):
"""Verify connect passes resolver arguments correctly."""
with patch('a2a.client.client_factory.A2ACardResolver') as mock_resolver:
mock_resolver.return_value.get_agent_card = AsyncMock(
return_value=base_agent_card
)

agent_url = 'http://example.com'
relative_path = '/card'
http_kwargs = {'headers': {'X-Test': 'true'}}

# The resolver args are only passed if an httpx_client is provided in config
config = ClientConfig(httpx_client=httpx.AsyncClient())

await ClientFactory.connect(
agent_url,
client_config=config,
relative_card_path=relative_path,
resolver_http_kwargs=http_kwargs,
)

mock_resolver.return_value.get_agent_card.assert_awaited_once_with(
relative_card_path=relative_path,
http_kwargs=http_kwargs,
)


@pytest.mark.asyncio
async def test_client_factory_connect_resolver_args_ignored_without_client(
base_agent_card: AgentCard,
):
"""Verify resolver args are ignored if no httpx_client is provided."""
with patch('a2a.client.client_factory.A2ACardResolver') as mock_resolver:
mock_resolver.return_value.get_agent_card = AsyncMock(
return_value=base_agent_card
)

agent_url = 'http://example.com'
relative_path = '/card'
http_kwargs = {'headers': {'X-Test': 'true'}}

await ClientFactory.connect(
agent_url,
relative_card_path=relative_path,
resolver_http_kwargs=http_kwargs,
)

mock_resolver.return_value.get_agent_card.assert_awaited_once_with()


@pytest.mark.asyncio
async def test_client_factory_connect_with_extra_transports(
base_agent_card: AgentCard,
):
"""Verify that connect can register and use extra transports."""

class CustomTransport:
pass

def custom_transport_producer(*args, **kwargs):
return CustomTransport()

base_agent_card.preferred_transport = 'custom'
base_agent_card.url = 'custom://foo'

config = ClientConfig(supported_transports=['custom'])

client = await ClientFactory.connect(
base_agent_card,
client_config=config,
extra_transports={'custom': custom_transport_producer},
)

assert isinstance(client._transport, CustomTransport)


@pytest.mark.asyncio
async def test_client_factory_connect_with_consumers_and_interceptors(
base_agent_card: AgentCard,
):
"""Verify consumers and interceptors are passed through correctly."""
consumer1 = MagicMock()
interceptor1 = MagicMock()

with patch('a2a.client.client_factory.BaseClient') as mock_base_client:
await ClientFactory.connect(
base_agent_card,
consumers=[consumer1],
interceptors=[interceptor1],
)

mock_base_client.assert_called_once()
call_args = mock_base_client.call_args[0]
assert call_args[3] == [consumer1]
assert call_args[4] == [interceptor1]
Loading