Skip to content

Commit ceb6a82

Browse files
authored
more linting (#249)
More linting
1 parent 876c709 commit ceb6a82

File tree

13 files changed

+181
-117
lines changed

13 files changed

+181
-117
lines changed

langchain_mcp_adapters/client.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,10 @@ def __init__(self, connections: dict[str, Connection] | None = None) -> None:
8282

8383
@asynccontextmanager
8484
async def session(
85-
self, server_name: str, *, auto_initialize: bool = True
85+
self,
86+
server_name: str,
87+
*,
88+
auto_initialize: bool = True,
8689
) -> AsyncIterator[ClientSession]:
8790
"""Connect to an MCP server and initialize a session.
8891
@@ -136,15 +139,21 @@ async def get_tools(self, *, server_name: str | None = None) -> list[BaseTool]:
136139
return all_tools
137140

138141
async def get_prompt(
139-
self, server_name: str, prompt_name: str, *, arguments: dict[str, Any] | None = None
142+
self,
143+
server_name: str,
144+
prompt_name: str,
145+
*,
146+
arguments: dict[str, Any] | None = None,
140147
) -> list[HumanMessage | AIMessage]:
141148
"""Get a prompt from a given MCP server."""
142149
async with self.session(server_name) as session:
143-
prompt = await load_mcp_prompt(session, prompt_name, arguments=arguments)
144-
return prompt
150+
return await load_mcp_prompt(session, prompt_name, arguments=arguments)
145151

146152
async def get_resources(
147-
self, server_name: str, *, uris: str | list[str] | None = None
153+
self,
154+
server_name: str,
155+
*,
156+
uris: str | list[str] | None = None,
148157
) -> list[Blob]:
149158
"""Get resources from a given MCP server.
150159
@@ -157,8 +166,7 @@ async def get_resources(
157166
158167
"""
159168
async with self.session(server_name) as session:
160-
resources = await load_mcp_resources(session, uris=uris)
161-
return resources
169+
return await load_mcp_resources(session, uris=uris)
162170

163171
async def __aenter__(self) -> "MultiServerMCPClient":
164172
raise NotImplementedError(ASYNC_CONTEXT_MANAGER_ERROR)

langchain_mcp_adapters/prompts.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,10 @@ def convert_mcp_prompt_message_to_langchain_message(
3030

3131

3232
async def load_mcp_prompt(
33-
session: ClientSession, name: str, *, arguments: dict[str, Any] | None = None
33+
session: ClientSession,
34+
name: str,
35+
*,
36+
arguments: dict[str, Any] | None = None,
3437
) -> list[HumanMessage | AIMessage]:
3538
"""Load MCP prompt and convert to LangChain messages."""
3639
response = await session.get_prompt(name, arguments)

langchain_mcp_adapters/resources.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,9 @@ async def get_mcp_resource(session: ClientSession, uri: str) -> list[Blob]:
4848

4949

5050
async def load_mcp_resources(
51-
session: ClientSession, *, uris: str | list[str] | None = None
51+
session: ClientSession,
52+
*,
53+
uris: str | list[str] | None = None,
5254
) -> list[Blob]:
5355
"""Load MCP resources and convert them to LangChain Blobs.
5456
@@ -74,12 +76,14 @@ async def load_mcp_resources(
7476
else:
7577
uri_list = uris
7678

77-
for uri in uri_list:
78-
try:
79+
current_uri = None
80+
try:
81+
for uri in uri_list:
82+
current_uri = uri
7983
resource_blobs = await get_mcp_resource(session, uri)
8084
blobs.extend(resource_blobs)
81-
except Exception as e:
82-
msg = f"Error fetching resource {uri}"
83-
raise RuntimeError(msg) from e
85+
except Exception as e:
86+
msg = f"Error fetching resource {current_uri}"
87+
raise RuntimeError(msg) from e
8488

8589
return blobs

langchain_mcp_adapters/sessions.py

Lines changed: 44 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,22 @@
11
from __future__ import annotations
22

33
import os
4-
from collections.abc import AsyncIterator
54
from contextlib import asynccontextmanager
65
from datetime import timedelta
7-
from pathlib import Path
8-
from typing import Any, Literal, Protocol
6+
from typing import TYPE_CHECKING, Any, Literal, Protocol
97

10-
import httpx
118
from mcp import ClientSession, StdioServerParameters
129
from mcp.client.sse import sse_client
1310
from mcp.client.stdio import stdio_client
1411
from mcp.client.streamable_http import streamablehttp_client
1512
from typing_extensions import NotRequired, TypedDict
1613

14+
if TYPE_CHECKING:
15+
from collections.abc import AsyncIterator
16+
from pathlib import Path
17+
18+
import httpx
19+
1720
EncodingErrorHandler = Literal["strict", "ignore", "replace"]
1821

1922
DEFAULT_ENCODING = "utf-8"
@@ -147,7 +150,7 @@ class WebsocketConnection(TypedDict):
147150

148151

149152
@asynccontextmanager
150-
async def _create_stdio_session(
153+
async def _create_stdio_session( # noqa: PLR0913
151154
*,
152155
command: str,
153156
args: list[str],
@@ -186,13 +189,15 @@ async def _create_stdio_session(
186189
)
187190

188191
# Create and store the connection
189-
async with stdio_client(server_params) as (read, write):
190-
async with ClientSession(read, write, **(session_kwargs or {})) as session:
191-
yield session
192+
async with (
193+
stdio_client(server_params) as (read, write),
194+
ClientSession(read, write, **(session_kwargs or {})) as session,
195+
):
196+
yield session
192197

193198

194199
@asynccontextmanager
195-
async def _create_sse_session(
200+
async def _create_sse_session( # noqa: PLR0913
196201
*,
197202
url: str,
198203
headers: dict[str, Any] | None = None,
@@ -219,16 +224,18 @@ async def _create_sse_session(
219224
if httpx_client_factory is not None:
220225
kwargs["httpx_client_factory"] = httpx_client_factory
221226

222-
async with sse_client(url, headers, timeout, sse_read_timeout, auth=auth, **kwargs) as (
223-
read,
224-
write,
227+
async with (
228+
sse_client(url, headers, timeout, sse_read_timeout, auth=auth, **kwargs) as (
229+
read,
230+
write,
231+
),
232+
ClientSession(read, write, **(session_kwargs or {})) as session,
225233
):
226-
async with ClientSession(read, write, **(session_kwargs or {})) as session:
227-
yield session
234+
yield session
228235

229236

230237
@asynccontextmanager
231-
async def _create_streamable_http_session(
238+
async def _create_streamable_http_session( # noqa: PLR0913
232239
*,
233240
url: str,
234241
headers: dict[str, Any] | None = None,
@@ -257,16 +264,26 @@ async def _create_streamable_http_session(
257264
if httpx_client_factory is not None:
258265
kwargs["httpx_client_factory"] = httpx_client_factory
259266

260-
async with streamablehttp_client(
261-
url, headers, timeout, sse_read_timeout, terminate_on_close, auth=auth, **kwargs
262-
) as (read, write, _):
263-
async with ClientSession(read, write, **(session_kwargs or {})) as session:
264-
yield session
267+
async with (
268+
streamablehttp_client(
269+
url,
270+
headers,
271+
timeout,
272+
sse_read_timeout,
273+
terminate_on_close,
274+
auth=auth,
275+
**kwargs,
276+
) as (read, write, _),
277+
ClientSession(read, write, **(session_kwargs or {})) as session,
278+
):
279+
yield session
265280

266281

267282
@asynccontextmanager
268283
async def _create_websocket_session(
269-
*, url: str, session_kwargs: dict[str, Any] | None = None
284+
*,
285+
url: str,
286+
session_kwargs: dict[str, Any] | None = None,
270287
) -> AsyncIterator[ClientSession]:
271288
"""Create a new session to an MCP server using Websockets.
272289
@@ -288,13 +305,15 @@ async def _create_websocket_session(
288305
)
289306
raise ImportError(msg) from None
290307

291-
async with websocket_client(url) as (read, write):
292-
async with ClientSession(read, write, **(session_kwargs or {})) as session:
293-
yield session
308+
async with (
309+
websocket_client(url) as (read, write),
310+
ClientSession(read, write, **(session_kwargs or {})) as session,
311+
):
312+
yield session
294313

295314

296315
@asynccontextmanager
297-
async def create_session(connection: Connection) -> AsyncIterator[ClientSession]:
316+
async def create_session(connection: Connection) -> AsyncIterator[ClientSession]: # noqa: C901
298317
"""Create a new session to an MCP server.
299318
300319
Args:

langchain_mcp_adapters/tools.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,10 @@ async def _list_all_tools(session: ClientSession) -> list[MCPTool]:
6363

6464

6565
def convert_mcp_tool_to_langchain_tool(
66-
session: ClientSession | None, tool: MCPTool, *, connection: Connection | None = None
66+
session: ClientSession | None,
67+
tool: MCPTool,
68+
*,
69+
connection: Connection | None = None,
6770
) -> BaseTool:
6871
"""Convert an MCP tool to a LangChain tool.
6972
@@ -91,7 +94,8 @@ async def call_tool(
9194
async with create_session(connection) as tool_session:
9295
await tool_session.initialize()
9396
call_tool_result = await cast("ClientSession", tool_session).call_tool(
94-
tool.name, arguments
97+
tool.name,
98+
arguments,
9599
)
96100
else:
97101
call_tool_result = await session.call_tool(tool.name, arguments)
@@ -108,7 +112,9 @@ async def call_tool(
108112

109113

110114
async def load_mcp_tools(
111-
session: ClientSession | None, *, connection: Connection | None = None
115+
session: ClientSession | None,
116+
*,
117+
connection: Connection | None = None,
112118
) -> list[BaseTool]:
113119
"""Load all available MCP tools and convert them to LangChain tools.
114120
@@ -129,10 +135,9 @@ async def load_mcp_tools(
129135
else:
130136
tools = await _list_all_tools(session)
131137

132-
converted_tools = [
138+
return [
133139
convert_mcp_tool_to_langchain_tool(session, tool, connection=connection) for tool in tools
134140
]
135-
return converted_tools
136141

137142

138143
def _get_injected_args(tool: BaseTool) -> list[str]:
@@ -143,12 +148,11 @@ def _is_injected_arg_type(type_: type) -> bool:
143148
for arg in get_args(type_)[1:]
144149
)
145150

146-
injected_args = [
151+
return [
147152
field
148153
for field, field_info in get_all_basemodel_annotations(tool.args_schema).items()
149154
if _is_injected_arg_type(field_info)
150155
]
151-
return injected_args
152156

153157

154158
def to_fastmcp(tool: BaseTool) -> FastMCPTool:
@@ -168,19 +172,21 @@ def to_fastmcp(tool: BaseTool) -> FastMCPTool:
168172
arg_model = create_model(f"{tool.name}Arguments", **field_definitions, __base__=ArgModelBase)
169173
fn_metadata = FuncMetadata(arg_model=arg_model)
170174

171-
async def fn(**arguments: dict[str, Any]) -> Any:
175+
# We'll use an Any type for the function return type.
176+
# We're providing the parameters separately
177+
async def fn(**arguments: dict[str, Any]) -> Any: # noqa: ANN401
172178
return await tool.ainvoke(arguments)
173179

174180
injected_args = _get_injected_args(tool)
175181
if len(injected_args) > 0:
176-
raise NotImplementedError("LangChain tools with injected arguments are not supported")
182+
msg = "LangChain tools with injected arguments are not supported"
183+
raise NotImplementedError(msg)
177184

178-
fastmcp_tool = FastMCPTool(
185+
return FastMCPTool(
179186
fn=fn,
180187
name=tool.name,
181188
description=tool.description,
182189
parameters=parameters,
183190
fn_metadata=fn_metadata,
184191
is_async=True,
185192
)
186-
return fastmcp_tool

pyproject.toml

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -52,23 +52,12 @@ target-version = "py310"
5252
select = [ "ALL",]
5353
ignore = [
5454
"E501", # line-length
55-
"ARG",
56-
"ANN401",
57-
"C901",
58-
"COM812",
55+
"COM812", # conflict with formatter
5956
"D100",
6057
"D101",
6158
"D102",
6259
"D104",
6360
"D105",
64-
"EM101",
65-
"EM102",
66-
"PERF203",
67-
"PLR0913",
68-
"RET504",
69-
"SIM117",
70-
"TC002",
71-
"TC003",
7261
]
7362

7463

tests/conftest.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@ def websocket_server_port() -> int:
1919
@pytest.fixture
2020
def websocket_server(websocket_server_port: int) -> Generator[None, None, None]:
2121
proc = multiprocessing.Process(
22-
target=run_server, kwargs={"server_port": websocket_server_port}, daemon=True
22+
target=run_server,
23+
kwargs={"server_port": websocket_server_port},
24+
daemon=True,
2325
)
2426
proc.start()
2527

tests/servers/math_server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def configure_assistant(skills: str) -> list[dict]:
2121
{
2222
"role": "assistant",
2323
"content": f"You are a helpful assistant. You have the following skills: {skills}. Always use only one tool at a time.",
24-
}
24+
},
2525
]
2626

2727

0 commit comments

Comments
 (0)