Skip to content

Commit 628a6b1

Browse files
Added property resource_templates and reas_resource that were present in session but missing in session_group.
TESTED=unit tests
1 parent b0b44c2 commit 628a6b1

File tree

2 files changed

+104
-13
lines changed

2 files changed

+104
-13
lines changed

src/mcp/client/session_group.py

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from typing import Any, TypeAlias
1717

1818
import anyio
19-
from pydantic import BaseModel
19+
from pydantic import BaseModel, AnyUrl
2020
from typing_extensions import Self
2121

2222
import mcp
@@ -100,6 +100,7 @@ class _ComponentNames(BaseModel):
100100
# Client-server connection management.
101101
_sessions: dict[mcp.ClientSession, _ComponentNames]
102102
_tool_to_session: dict[str, mcp.ClientSession]
103+
_resource_to_session: dict[str, mcp.ClientSession]
103104
_exit_stack: contextlib.AsyncExitStack
104105
_session_exit_stacks: dict[mcp.ClientSession, contextlib.AsyncExitStack]
105106

@@ -116,20 +117,16 @@ def __init__(
116117
) -> None:
117118
"""Initializes the MCP client."""
118119

119-
self._tools = {}
120-
self._resources = {}
120+
self._exit_stack = exit_stack or contextlib.AsyncExitStack()
121+
self._owns_exit_stack = exit_stack is None
122+
self._session_exit_stacks = {}
123+
self._component_name_hook = component_name_hook
121124
self._prompts = {}
122-
125+
self._resources = {}
126+
self._tools = {}
123127
self._sessions = {}
124128
self._tool_to_session = {}
125-
if exit_stack is None:
126-
self._exit_stack = contextlib.AsyncExitStack()
127-
self._owns_exit_stack = True
128-
else:
129-
self._exit_stack = exit_stack
130-
self._owns_exit_stack = False
131-
self._session_exit_stacks = {}
132-
self._component_name_hook = component_name_hook
129+
self._resource_to_session = {} # New mapping
133130

134131
async def __aenter__(self) -> Self:
135132
# Enter the exit stack only if we created it ourselves
@@ -174,6 +171,16 @@ def tools(self) -> dict[str, types.Tool]:
174171
"""Returns the tools as a dictionary of names to tools."""
175172
return self._tools
176173

174+
@property
175+
def resource_templates(self) -> list[types.ResourceTemplate]:
176+
"""Return all unique resource templates from the resources."""
177+
templates: list[types.ResourceTemplate] = []
178+
for r in self._resources.values():
179+
t = getattr(r, "template", None)
180+
if t is not None and t not in templates:
181+
templates.append(t)
182+
return templates
183+
177184
async def call_tool(self, name: str, args: dict[str, Any]) -> types.CallToolResult:
178185
"""Executes a tool given its name and arguments."""
179186
session = self._tool_to_session[name]
@@ -296,8 +303,8 @@ async def _aggregate_components(
296303
resources_temp: dict[str, types.Resource] = {}
297304
tools_temp: dict[str, types.Tool] = {}
298305
tool_to_session_temp: dict[str, mcp.ClientSession] = {}
306+
resource_to_session_temp: dict[str, mcp.ClientSession] = {}
299307

300-
# Query the server for its prompts and aggregate to list.
301308
try:
302309
prompts = (await session.list_prompts()).prompts
303310
for prompt in prompts:
@@ -314,6 +321,7 @@ async def _aggregate_components(
314321
name = self._component_name(resource.name, server_info)
315322
resources_temp[name] = resource
316323
component_names.resources.add(name)
324+
resource_to_session_temp[name] = session
317325
except McpError as err:
318326
logging.warning(f"Could not fetch resources: {err}")
319327

@@ -365,8 +373,20 @@ async def _aggregate_components(
365373
self._resources.update(resources_temp)
366374
self._tools.update(tools_temp)
367375
self._tool_to_session.update(tool_to_session_temp)
376+
self._resource_to_session.update(resource_to_session_temp)
368377

369378
def _component_name(self, name: str, server_info: types.Implementation) -> str:
370379
if self._component_name_hook:
371380
return self._component_name_hook(name, server_info)
372381
return name
382+
383+
async def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult:
384+
"""Read a resource from the appropriate session based on the URI."""
385+
print(self._resources)
386+
print(self._resource_to_session)
387+
for name, resource in self._resources.items():
388+
if resource.uri == uri:
389+
session = self._resource_to_session.get(name)
390+
if session:
391+
return await session.read_resource(uri)
392+
raise ValueError(f"Resource not found: {uri}")

tests/client/test_session_group.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from unittest import mock
33

44
import pytest
5+
from pydantic import AnyUrl
56

67
import mcp
78
from mcp import types
@@ -395,3 +396,73 @@ async def test_establish_session_parameterized(
395396
# 3. Assert returned values
396397
assert returned_server_info is mock_initialize_result.serverInfo
397398
assert returned_session is mock_entered_session
399+
400+
@pytest.mark.anyio
401+
async def test_read_resource_not_found(self):
402+
"""Test reading a non-existent resource from a session group."""
403+
# --- Mock Dependencies ---
404+
mock_session = mock.AsyncMock(spec=mcp.ClientSession)
405+
test_resource = types.Resource(
406+
name="test_resource",
407+
uri=AnyUrl("test://resource/1"),
408+
description="Test resource"
409+
)
410+
411+
# Mock all list methods
412+
mock_session.list_resources.return_value = types.ListResourcesResult(resources=[test_resource])
413+
mock_session.list_prompts.return_value = types.ListPromptsResult(prompts=[])
414+
mock_session.list_tools.return_value = types.ListToolsResult(tools=[])
415+
416+
# --- Test Setup ---
417+
group = ClientSessionGroup()
418+
group._session_exit_stacks[mock_session] = mock.AsyncMock(spec=contextlib.AsyncExitStack)
419+
await group.connect_with_session(
420+
types.Implementation(name="test_server", version="1.0.0"),
421+
mock_session
422+
)
423+
424+
# --- Test Execution & Assertions ---
425+
with pytest.raises(ValueError, match="Resource not found: test://nonexistent"):
426+
await group.read_resource(AnyUrl("test://nonexistent"))
427+
428+
@pytest.mark.anyio
429+
async def test_read_resource_success(self):
430+
"""Test successfully reading a resource from a session group."""
431+
# --- Mock Dependencies ---
432+
mock_session = mock.AsyncMock(spec=mcp.ClientSession)
433+
test_resource = types.Resource(
434+
name="test_resource",
435+
uri=AnyUrl("test://resource/1"),
436+
description="Test resource"
437+
)
438+
439+
# Mock all list methods
440+
mock_session.list_resources.return_value = types.ListResourcesResult(resources=[test_resource])
441+
mock_session.list_prompts.return_value = types.ListPromptsResult(prompts=[])
442+
mock_session.list_tools.return_value = types.ListToolsResult(tools=[])
443+
444+
# Mock the session's read_resource method
445+
mock_read_result = mock.AsyncMock(spec=types.ReadResourceResult)
446+
mock_read_result.content = [types.TextContent(type="text", text="Resource content")]
447+
mock_session.read_resource.return_value = mock_read_result
448+
449+
# --- Test Setup ---
450+
group = ClientSessionGroup()
451+
group._session_exit_stacks[mock_session] = mock.AsyncMock(spec=contextlib.AsyncExitStack)
452+
await group.connect_with_session(
453+
types.Implementation(name="test_server", version="1.0.0"),
454+
mock_session
455+
)
456+
457+
# Verify resource was added
458+
assert "test_resource" in group._resources
459+
assert group._resources["test_resource"] == test_resource
460+
assert "test_resource" in group._resource_to_session
461+
assert group._resource_to_session["test_resource"] == mock_session
462+
463+
# --- Test Execution ---
464+
result = await group.read_resource(AnyUrl("test://resource/1"))
465+
466+
# --- Assertions ---
467+
assert result.content == [types.TextContent(type="text", text="Resource content")]
468+
mock_session.read_resource.assert_called_once_with(AnyUrl("test://resource/1"))

0 commit comments

Comments
 (0)