Skip to content

Commit 1e59033

Browse files
authored
typing: fix basic typing issues (#247)
1 parent 7ff9929 commit 1e59033

File tree

1 file changed

+45
-60
lines changed

1 file changed

+45
-60
lines changed

langchain_mcp_adapters/sessions.py

Lines changed: 45 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -44,24 +44,29 @@ class StdioConnection(TypedDict):
4444
args: list[str]
4545
"""Command line arguments to pass to the executable."""
4646

47-
env: dict[str, str] | None
47+
env: NotRequired[dict[str, str] | None]
4848
"""The environment to use when spawning the process."""
4949

50-
cwd: str | Path | None
50+
cwd: NotRequired[str | Path | None]
5151
"""The working directory to use when spawning the process."""
5252

53-
encoding: str
54-
"""The text encoding used when sending/receiving messages to the server."""
53+
encoding: NotRequired[str]
54+
"""The text encoding used when sending/receiving messages to the server.
5555
56-
encoding_error_handler: EncodingErrorHandler
56+
Default is 'utf-8'.
57+
"""
58+
59+
encoding_error_handler: NotRequired[EncodingErrorHandler]
5760
"""
5861
The text encoding error handler.
5962
6063
See https://docs.python.org/3/library/codecs.html#codec-base-classes for
6164
explanations of possible values.
65+
66+
Default is 'strict', which raises an error on encoding/decoding errors.
6267
"""
6368

64-
session_kwargs: dict[str, Any] | None
69+
session_kwargs: NotRequired[dict[str, Any] | None]
6570
"""Additional keyword arguments to pass to the ClientSession."""
6671

6772

@@ -71,19 +76,27 @@ class SSEConnection(TypedDict):
7176
url: str
7277
"""The URL of the SSE endpoint to connect to."""
7378

74-
headers: dict[str, Any] | None
79+
headers: NotRequired[dict[str, Any] | None]
7580
"""HTTP headers to send to the SSE endpoint."""
7681

77-
timeout: float
78-
"""HTTP timeout."""
82+
timeout: NotRequired[float]
83+
"""HTTP timeout.
84+
85+
Default is 5 seconds. If the server takes longer to respond,
86+
you can increase this value.
87+
"""
88+
89+
sse_read_timeout: NotRequired[float]
90+
"""SSE read timeout.
7991
80-
sse_read_timeout: float
81-
"""SSE read timeout."""
92+
Default is 300 seconds (5 minutes). This is how long the client will
93+
wait for a new event before disconnecting.
94+
"""
8295

83-
session_kwargs: dict[str, Any] | None
96+
session_kwargs: NotRequired[dict[str, Any] | None]
8497
"""Additional keyword arguments to pass to the ClientSession."""
8598

86-
httpx_client_factory: McpHttpClientFactory | None
99+
httpx_client_factory: NotRequired[McpHttpClientFactory | None]
87100
"""Custom factory for httpx.AsyncClient (optional)."""
88101

89102
auth: NotRequired[httpx.Auth]
@@ -92,27 +105,28 @@ class SSEConnection(TypedDict):
92105

93106
class StreamableHttpConnection(TypedDict):
94107
transport: Literal["streamable_http"]
108+
"""Connection configuration for Streamable HTTP transport."""
95109

96110
url: str
97111
"""The URL of the endpoint to connect to."""
98112

99-
headers: dict[str, Any] | None
113+
headers: NotRequired[dict[str, Any] | None]
100114
"""HTTP headers to send to the endpoint."""
101115

102-
timeout: timedelta
116+
timeout: NotRequired[timedelta]
103117
"""HTTP timeout."""
104118

105-
sse_read_timeout: timedelta
119+
sse_read_timeout: NotRequired[timedelta]
106120
"""How long (in seconds) the client will wait for a new event before disconnecting.
107121
All other HTTP operations are controlled by `timeout`."""
108122

109-
terminate_on_close: bool
123+
terminate_on_close: NotRequired[bool]
110124
"""Whether to terminate the session on close."""
111125

112-
session_kwargs: dict[str, Any] | None
126+
session_kwargs: NotRequired[dict[str, Any] | None]
113127
"""Additional keyword arguments to pass to the ClientSession."""
114128

115-
httpx_client_factory: McpHttpClientFactory | None
129+
httpx_client_factory: NotRequired[McpHttpClientFactory | None]
116130
"""Custom factory for httpx.AsyncClient (optional)."""
117131

118132
auth: NotRequired[httpx.Auth]
@@ -125,7 +139,7 @@ class WebsocketConnection(TypedDict):
125139
url: str
126140
"""The URL of the Websocket endpoint to connect to."""
127141

128-
session_kwargs: dict[str, Any] | None
142+
session_kwargs: NotRequired[dict[str, Any] | None]
129143
"""Additional keyword arguments to pass to the ClientSession"""
130144

131145

@@ -291,7 +305,6 @@ async def create_session(connection: Connection) -> AsyncIterator[ClientSession]
291305
292306
Yields:
293307
A ClientSession
294-
295308
"""
296309
if "transport" not in connection:
297310
raise ValueError(
@@ -302,57 +315,29 @@ async def create_session(connection: Connection) -> AsyncIterator[ClientSession]
302315
)
303316

304317
transport = connection["transport"]
318+
params = {k: v for k, v in connection.items() if k != "transport"}
319+
305320
if transport == "sse":
306-
if "url" not in connection:
321+
if "url" not in params:
307322
raise ValueError("'url' parameter is required for SSE connection")
308-
async with _create_sse_session(
309-
url=connection["url"],
310-
headers=connection.get("headers"),
311-
timeout=connection.get("timeout", DEFAULT_HTTP_TIMEOUT),
312-
sse_read_timeout=connection.get("sse_read_timeout", DEFAULT_SSE_READ_TIMEOUT),
313-
session_kwargs=connection.get("session_kwargs"),
314-
httpx_client_factory=connection.get("httpx_client_factory"),
315-
auth=connection.get("auth"),
316-
) as session:
323+
async with _create_sse_session(**params) as session:
317324
yield session
318325
elif transport == "streamable_http":
319-
if "url" not in connection:
326+
if "url" not in params:
320327
raise ValueError("'url' parameter is required for Streamable HTTP connection")
321-
async with _create_streamable_http_session(
322-
url=connection["url"],
323-
headers=connection.get("headers"),
324-
timeout=connection.get("timeout", DEFAULT_STREAMABLE_HTTP_TIMEOUT),
325-
sse_read_timeout=connection.get(
326-
"sse_read_timeout", DEFAULT_STREAMABLE_HTTP_SSE_READ_TIMEOUT
327-
),
328-
session_kwargs=connection.get("session_kwargs"),
329-
httpx_client_factory=connection.get("httpx_client_factory"),
330-
auth=connection.get("auth"),
331-
) as session:
328+
async with _create_streamable_http_session(**params) as session:
332329
yield session
333330
elif transport == "stdio":
334-
if "command" not in connection:
331+
if "command" not in params:
335332
raise ValueError("'command' parameter is required for stdio connection")
336-
if "args" not in connection:
333+
if "args" not in params:
337334
raise ValueError("'args' parameter is required for stdio connection")
338-
async with _create_stdio_session(
339-
command=connection["command"],
340-
args=connection["args"],
341-
env=connection.get("env"),
342-
cwd=connection.get("cwd"),
343-
encoding=connection.get("encoding", DEFAULT_ENCODING),
344-
encoding_error_handler=connection.get(
345-
"encoding_error_handler", DEFAULT_ENCODING_ERROR_HANDLER
346-
),
347-
session_kwargs=connection.get("session_kwargs"),
348-
) as session:
335+
async with _create_stdio_session(**params) as session:
349336
yield session
350337
elif transport == "websocket":
351-
if "url" not in connection:
338+
if "url" not in params:
352339
raise ValueError("'url' parameter is required for Websocket connection")
353-
async with _create_websocket_session(
354-
url=connection["url"], session_kwargs=connection.get("session_kwargs")
355-
) as session:
340+
async with _create_websocket_session(**params) as session:
356341
yield session
357342
else:
358343
raise ValueError(

0 commit comments

Comments
 (0)