Skip to content

Commit eb4c887

Browse files
committed
feat: Add support for custom HTTPX client factory in StreamableHTTPConnectionParams
1 parent 6e834d3 commit eb4c887

File tree

2 files changed

+31
-2
lines changed

2 files changed

+31
-2
lines changed

src/google/adk/tools/mcp_tool/mcp_session_manager.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,17 @@
2727
from typing import Optional
2828
from typing import TextIO
2929
from typing import Union
30+
from typing import runtime_checkable
3031

3132
import anyio
32-
from pydantic import BaseModel
33+
from pydantic import BaseModel, ConfigDict
3334

3435
try:
3536
from mcp import ClientSession
3637
from mcp import StdioServerParameters
3738
from mcp.client.sse import sse_client
3839
from mcp.client.stdio import stdio_client
39-
from mcp.client.streamable_http import streamablehttp_client
40+
from mcp.client.streamable_http import streamablehttp_client, McpHttpClientFactory
4041
except ImportError as e:
4142

4243
if sys.version_info < (3, 10):
@@ -101,11 +102,14 @@ class StreamableHTTPConnectionParams(BaseModel):
101102
when the connection is closed.
102103
"""
103104

105+
model_config = ConfigDict(arbitrary_types_allowed=True, )
106+
104107
url: str
105108
headers: dict[str, Any] | None = None
106109
timeout: float = 5.0
107110
sse_read_timeout: float = 60 * 5.0
108111
terminate_on_close: bool = True
112+
httpx_client_factory: Optional[runtime_checkable(McpHttpClientFactory)] = None
109113

110114

111115
def retry_on_closed_resource(func):
@@ -285,6 +289,7 @@ def _create_client(self, merged_headers: Optional[Dict[str, str]] = None):
285289
seconds=self._connection_params.sse_read_timeout
286290
),
287291
terminate_on_close=self._connection_params.terminate_on_close,
292+
httpx_client_factory=self._connection_params.httpx_client_factory,
288293
)
289294
else:
290295
raise ValueError(

tests/unittests/tools/mcp_tool/test_mcp_session_manager.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,30 @@ def test_init_with_streamable_http_params(self):
144144

145145
assert manager._connection_params == http_params
146146

147+
@pytest.mark.asyncio
148+
async def test_init_with_streamable_http_custom_httpx_factory(self):
149+
"""Test initialization with StreamableHTTPConnectionParams."""
150+
import httpx
151+
custom_httpx_client = httpx.AsyncClient()
152+
153+
def _httpx_factory(headers=None, timeout=None, auth=None):
154+
return custom_httpx_client
155+
156+
custom_httpx_factory = Mock(side_effect=_httpx_factory)
157+
158+
http_params = StreamableHTTPConnectionParams(
159+
url="https://example.com/mcp",
160+
timeout=15.0,
161+
httpx_client_factory=custom_httpx_factory,
162+
)
163+
manager = MCPSessionManager(http_params)
164+
165+
async with manager._create_client():
166+
#assert factory was called
167+
custom_httpx_factory.assert_called_once()
168+
169+
170+
147171
def test_generate_session_key_stdio(self):
148172
"""Test session key generation for stdio connections."""
149173
manager = MCPSessionManager(self.mock_stdio_connection_params)

0 commit comments

Comments
 (0)