Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 19 additions & 31 deletions python/packages/core/agent_framework/_mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@
from functools import partial
from typing import TYPE_CHECKING, Any, Literal

import httpx
from mcp import types
from mcp.client.session import ClientSession
from mcp.client.stdio import StdioServerParameters, stdio_client
from mcp.client.streamable_http import streamablehttp_client
from mcp.client.streamable_http import streamable_http_client
from mcp.client.websocket import websocket_client
from mcp.shared.context import RequestContext
from mcp.shared.exceptions import McpError
Expand Down Expand Up @@ -897,7 +898,6 @@ class MCPStreamableHTTPTool(MCPTool):
mcp_tool = MCPStreamableHTTPTool(
name="web-api",
url="https://api.example.com/mcp",
headers={"Authorization": "Bearer token"},
description="Web API operations",
)

Expand All @@ -919,21 +919,19 @@ def __init__(
description: str | None = None,
approval_mode: (Literal["always_require", "never_require"] | HostedMCPSpecificApproval | None) = None,
allowed_tools: Collection[str] | None = None,
headers: dict[str, Any] | None = None,
timeout: float | None = None,
sse_read_timeout: float | None = None,
terminate_on_close: bool | None = None,
chat_client: "ChatClientProtocol | None" = None,
additional_properties: dict[str, Any] | None = None,
http_client: httpx.AsyncClient | None = None,
**kwargs: Any,
) -> None:
"""Initialize the MCP streamable HTTP tool.

Note:
The arguments are used to create a streamable HTTP client.
See ``mcp.client.streamable_http.streamablehttp_client`` for more details.
Any extra arguments passed to the constructor will be passed to the
streamable HTTP client constructor.
The arguments are used to create a streamable HTTP client using the
new ``mcp.client.streamable_http.streamable_http_client`` API.
If an httpx.AsyncClient is provided via ``http_client``, it will be used directly.
Otherwise, the ``streamable_http_client`` API will create and manage a default client.

Args:
name: The name of the tool.
Expand All @@ -953,12 +951,13 @@ def __init__(
A tool should not be listed in both, if so, it will require approval.
allowed_tools: A list of tools that are allowed to use this tool.
additional_properties: Additional properties.
headers: The headers to send with the request.
timeout: The timeout for the request.
sse_read_timeout: The timeout for reading from the SSE stream.
terminate_on_close: Close the transport when the MCP client is terminated.
chat_client: The chat client to use for sampling.
kwargs: Any extra arguments to pass to the SSE client.
http_client: Optional httpx.AsyncClient to use. If not provided, the
``streamable_http_client`` API will create and manage a default client.
To configure headers, timeouts, or other HTTP client settings, create
and pass your own ``httpx.AsyncClient`` instance.
kwargs: Additional keyword arguments (accepted for backward compatibility but not used).
"""
super().__init__(
name=name,
Expand All @@ -973,32 +972,21 @@ def __init__(
request_timeout=request_timeout,
)
self.url = url
self.headers = headers or {}
self.timeout = timeout
self.sse_read_timeout = sse_read_timeout
self.terminate_on_close = terminate_on_close
self._client_kwargs = kwargs
self._httpx_client: httpx.AsyncClient | None = http_client

def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
"""Get an MCP streamable HTTP client.

Returns:
An async context manager for the streamable HTTP client transport.
"""
args: dict[str, Any] = {
"url": self.url,
}
if self.headers:
args["headers"] = self.headers
if self.timeout is not None:
args["timeout"] = self.timeout
if self.sse_read_timeout is not None:
args["sse_read_timeout"] = self.sse_read_timeout
if self.terminate_on_close is not None:
args["terminate_on_close"] = self.terminate_on_close
if self._client_kwargs:
args.update(self._client_kwargs)
return streamablehttp_client(**args)
# Pass the http_client (which may be None) to streamable_http_client
return streamable_http_client(
url=self.url,
http_client=self._httpx_client,
terminate_on_close=self.terminate_on_close if self.terminate_on_close is not None else True,
)


class MCPWebsocketTool(MCPTool):
Expand Down
2 changes: 1 addition & 1 deletion python/packages/core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ dependencies = [
# connectors and functions
"openai>=1.99.0",
"azure-identity>=1,<2",
"mcp[ws]>=1.23",
"mcp[ws]>=1.24.0,<2",
"packaging>=24.1",
]

Expand Down
72 changes: 62 additions & 10 deletions python/packages/core/tests/core/test_mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -1512,24 +1512,18 @@ def test_mcp_streamable_http_tool_get_mcp_client_all_params():
tool = MCPStreamableHTTPTool(
name="test",
url="http://example.com",
headers={"Auth": "token"},
timeout=30.0,
sse_read_timeout=10.0,
terminate_on_close=True,
custom_param="test",
)

with patch("agent_framework._mcp.streamablehttp_client") as mock_http_client:
with patch("agent_framework._mcp.streamable_http_client") as mock_http_client:
tool.get_mcp_client()

# Verify all parameters were passed
# Verify streamable_http_client was called with None for http_client
# (since we didn't provide one, the API will create its own)
mock_http_client.assert_called_once_with(
url="http://example.com",
headers={"Auth": "token"},
timeout=30.0,
sse_read_timeout=10.0,
http_client=None,
terminate_on_close=True,
custom_param="test",
)


Expand Down Expand Up @@ -1692,3 +1686,61 @@ async def test_load_prompts_prevents_multiple_calls():
tool._prompts_loaded = True

assert mock_session.list_prompts.call_count == 1 # Still 1, not incremented


@pytest.mark.asyncio
async def test_mcp_streamable_http_tool_httpx_client_cleanup():
"""Test that MCPStreamableHTTPTool properly passes through httpx clients."""
from unittest.mock import AsyncMock, Mock, patch

from agent_framework import MCPStreamableHTTPTool

# Mock the streamable_http_client to avoid actual connections
with (
patch("agent_framework._mcp.streamable_http_client") as mock_client,
patch("agent_framework._mcp.ClientSession") as mock_session_class,
):
# Setup mock context manager for streamable_http_client
mock_transport = (Mock(), Mock())
mock_context_manager = Mock()
mock_context_manager.__aenter__ = AsyncMock(return_value=mock_transport)
mock_context_manager.__aexit__ = AsyncMock(return_value=None)
mock_client.return_value = mock_context_manager

# Setup mock session
mock_session = Mock()
mock_session.initialize = AsyncMock()
mock_session_class.return_value.__aenter__ = AsyncMock(return_value=mock_session)
mock_session_class.return_value.__aexit__ = AsyncMock(return_value=None)

# Test 1: Tool without provided client (passes None to streamable_http_client)
tool1 = MCPStreamableHTTPTool(
name="test",
url="http://localhost:8081/mcp",
load_tools=False,
load_prompts=False,
terminate_on_close=False,
)
await tool1.connect()
# When no client is provided, _httpx_client should be None
assert tool1._httpx_client is None, "httpx client should be None when not provided"

# Test 2: Tool with user-provided client
user_client = Mock()
tool2 = MCPStreamableHTTPTool(
name="test",
url="http://localhost:8081/mcp",
load_tools=False,
load_prompts=False,
terminate_on_close=False,
http_client=user_client,
)
await tool2.connect()

# Verify the user-provided client was stored
assert tool2._httpx_client is user_client, "User-provided client should be stored"

# Verify streamable_http_client was called with the user's client
# Get the last call (should be from tool2.connect())
call_args = mock_client.call_args
assert call_args.kwargs["http_client"] is user_client, "User's client should be passed through"
2 changes: 1 addition & 1 deletion python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ ignore = [
"TD003", # allow missing link to todo issue
"FIX002", # allow todo
"B027", # allow empty non-abstract method in ABC
"RUF067", # allow version detection in __init__.py
"RUF067" # Allow version in __init__.py
]

[tool.ruff.lint.per-file-ignores]
Expand Down
8 changes: 6 additions & 2 deletions python/samples/getting_started/mcp/mcp_api_key_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from agent_framework import ChatAgent, MCPStreamableHTTPTool
from agent_framework.openai import OpenAIResponsesClient
from httpx import AsyncClient

"""
MCP Authentication Example
Expand Down Expand Up @@ -31,13 +32,16 @@ async def api_key_auth_example() -> None:
"Authorization": f"Bearer {api_key}",
}

# Create MCP tool with authentication headers
# Create HTTP client with authentication headers
http_client = AsyncClient(headers=auth_headers)

# Create MCP tool with the configured HTTP client
async with (
MCPStreamableHTTPTool(
name="MCP tool",
description="MCP tool description",
url=mcp_server_url,
headers=auth_headers, # Authentication headers
http_client=http_client, # Pass HTTP client with authentication headers
) as mcp_tool,
ChatAgent(
chat_client=OpenAIResponsesClient(),
Expand Down
20 changes: 10 additions & 10 deletions python/uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading