Skip to content

Commit 489a00f

Browse files
committed
mypy fix
1 parent 2ac916a commit 489a00f

File tree

2 files changed

+45
-15
lines changed

2 files changed

+45
-15
lines changed

src/agents/mcp/server.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -649,14 +649,24 @@ def create_streams(
649649
]
650650
]:
651651
"""Create the streams for the server."""
652-
return streamablehttp_client(
653-
url=self.params["url"],
654-
headers=self.params.get("headers", None),
655-
timeout=self.params.get("timeout", 5),
656-
sse_read_timeout=self.params.get("sse_read_timeout", 60 * 5),
657-
terminate_on_close=self.params.get("terminate_on_close", True),
658-
httpx_client_factory=self.params.get("httpx_client_factory", None),
659-
)
652+
# Only pass httpx_client_factory if it's provided
653+
if "httpx_client_factory" in self.params:
654+
return streamablehttp_client(
655+
url=self.params["url"],
656+
headers=self.params.get("headers", None),
657+
timeout=self.params.get("timeout", 5),
658+
sse_read_timeout=self.params.get("sse_read_timeout", 60 * 5),
659+
terminate_on_close=self.params.get("terminate_on_close", True),
660+
httpx_client_factory=self.params["httpx_client_factory"],
661+
)
662+
else:
663+
return streamablehttp_client(
664+
url=self.params["url"],
665+
headers=self.params.get("headers", None),
666+
timeout=self.params.get("timeout", 5),
667+
sse_read_timeout=self.params.get("sse_read_timeout", 60 * 5),
668+
terminate_on_close=self.params.get("terminate_on_close", True),
669+
)
660670

661671
@property
662672
def name(self) -> str:

tests/mcp/test_streamable_http_client_factory.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,15 +36,19 @@ async def test_default_httpx_client_factory(self):
3636
timeout=10,
3737
sse_read_timeout=300, # Default value
3838
terminate_on_close=True, # Default value
39-
httpx_client_factory=None, # Should be None when not provided
39+
# httpx_client_factory should not be passed when not provided
4040
)
4141

4242
@pytest.mark.asyncio
4343
async def test_custom_httpx_client_factory(self):
4444
"""Test that custom httpx_client_factory is passed correctly."""
4545

4646
# Create a custom factory function
47-
def custom_factory() -> httpx.AsyncClient:
47+
def custom_factory(
48+
headers: dict[str, str] | None = None,
49+
timeout: httpx.Timeout | None = None,
50+
auth: httpx.Auth | None = None,
51+
) -> httpx.AsyncClient:
4852
return httpx.AsyncClient(
4953
verify=False, # Disable SSL verification for testing
5054
timeout=httpx.Timeout(60.0),
@@ -81,7 +85,11 @@ def custom_factory() -> httpx.AsyncClient:
8185
async def test_custom_httpx_client_factory_with_ssl_cert(self):
8286
"""Test custom factory with SSL certificate configuration."""
8387

84-
def ssl_cert_factory() -> httpx.AsyncClient:
88+
def ssl_cert_factory(
89+
headers: dict[str, str] | None = None,
90+
timeout: httpx.Timeout | None = None,
91+
auth: httpx.Auth | None = None,
92+
) -> httpx.AsyncClient:
8593
return httpx.AsyncClient(
8694
verify="/path/to/cert.pem", # Custom SSL certificate
8795
timeout=httpx.Timeout(120.0),
@@ -113,9 +121,13 @@ def ssl_cert_factory() -> httpx.AsyncClient:
113121
async def test_custom_httpx_client_factory_with_proxy(self):
114122
"""Test custom factory with proxy configuration."""
115123

116-
def proxy_factory() -> httpx.AsyncClient:
124+
def proxy_factory(
125+
headers: dict[str, str] | None = None,
126+
timeout: httpx.Timeout | None = None,
127+
auth: httpx.Auth | None = None,
128+
) -> httpx.AsyncClient:
117129
return httpx.AsyncClient(
118-
proxies="http://proxy.example.com:8080",
130+
proxy="http://proxy.example.com:8080",
119131
timeout=httpx.Timeout(60.0),
120132
)
121133

@@ -144,7 +156,11 @@ def proxy_factory() -> httpx.AsyncClient:
144156
async def test_custom_httpx_client_factory_with_retry_logic(self):
145157
"""Test custom factory with retry logic configuration."""
146158

147-
def retry_factory() -> httpx.AsyncClient:
159+
def retry_factory(
160+
headers: dict[str, str] | None = None,
161+
timeout: httpx.Timeout | None = None,
162+
auth: httpx.Auth | None = None,
163+
) -> httpx.AsyncClient:
148164
return httpx.AsyncClient(
149165
timeout=httpx.Timeout(30.0),
150166
# Note: httpx doesn't have built-in retry, but this shows how
@@ -194,7 +210,11 @@ def test_httpx_client_factory_type_annotation(self):
194210
async def test_all_parameters_with_custom_factory(self):
195211
"""Test that all parameters work together with custom factory."""
196212

197-
def comprehensive_factory() -> httpx.AsyncClient:
213+
def comprehensive_factory(
214+
headers: dict[str, str] | None = None,
215+
timeout: httpx.Timeout | None = None,
216+
auth: httpx.Auth | None = None,
217+
) -> httpx.AsyncClient:
198218
return httpx.AsyncClient(
199219
verify=False,
200220
timeout=httpx.Timeout(90.0),

0 commit comments

Comments
 (0)