Skip to content

Commit cca6102

Browse files
refactor: add abstract profile manager
1 parent aa69c9c commit cca6102

File tree

5 files changed

+217
-2
lines changed

5 files changed

+217
-2
lines changed
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import asyncio
2+
import logging
3+
from typing import Optional, cast
4+
5+
from mcp import ClientSession, InitializeResult, StdioServerParameters, stdio_client
6+
from mcp.client.sse import sse_client
7+
8+
from mcpm.core.schema import ServerConfig, SSEServerConfig, STDIOServerConfig
9+
10+
logger = logging.getLogger(__name__)
11+
12+
13+
def _stdio_transport_context(server_config: ServerConfig):
14+
server_config = cast(STDIOServerConfig, server_config)
15+
server_params = StdioServerParameters(command=server_config.command, args=server_config.args, env=server_config.env)
16+
return stdio_client(server_params)
17+
18+
19+
def _sse_transport_context(server_config: ServerConfig):
20+
server_config = cast(SSEServerConfig, server_config)
21+
return sse_client(server_config.url, headers=server_config.headers)
22+
23+
24+
class ServerConnection:
25+
def __init__(self, server_config: ServerConfig) -> None:
26+
self.session: Optional[ClientSession] = None
27+
self.session_initialized_response: Optional[InitializeResult] = None
28+
self._initialized = False
29+
self.server_config = server_config
30+
self._initialized_event = asyncio.Event()
31+
self._shutdown_event = asyncio.Event()
32+
33+
self._transport_context_factory = (
34+
_stdio_transport_context if isinstance(server_config, STDIOServerConfig) else _sse_transport_context
35+
)
36+
37+
self._server_task = asyncio.create_task(self._server_lifespan_cycle())
38+
39+
def healthy(self) -> bool:
40+
return self.session is not None and self._initialized
41+
42+
# block until client session is initialized
43+
async def wait_for_initialization(self):
44+
await self._initialized_event.wait()
45+
46+
# request for client session to gracefully close
47+
async def request_for_shutdown(self):
48+
self._shutdown_event.set()
49+
50+
# block until client session is shutdown
51+
async def wait_for_shutdown_request(self):
52+
await self._shutdown_event.wait()
53+
54+
async def _server_lifespan_cycle(self):
55+
try:
56+
async with self._transport_context_factory(self.server_config) as (read, write):
57+
async with ClientSession(read, write) as session:
58+
self.session_initialized_response = await session.initialize()
59+
60+
self.session = session
61+
self._initialized = True
62+
self._initialized_event.set()
63+
# block here so that the session will not be closed after exit scope
64+
# we could retrieve alive session through self.session
65+
await self.wait_for_shutdown_request()
66+
except Exception as e:
67+
logger.error(f"Failed to connect to server {self.server_config.name}: {e}")
68+
self._initialized_event.set()
69+
self._shutdown_event.set()
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
import json
2+
import os
3+
from typing import Dict, Optional, override
4+
5+
from mcpm.core.mcp.client_connection import ServerConnection
6+
from mcpm.core.profile.profile_manager import AbstractProfileManager
7+
from mcpm.core.schema import Profile, ServerConfig
8+
9+
DEFAULT_PROFILE_PATH = os.path.expanduser("~/.config/mcpm/profiles.json")
10+
11+
12+
class LocalProfileManager(AbstractProfileManager):
13+
def __init__(self, profile_path: str = DEFAULT_PROFILE_PATH):
14+
super().__init__()
15+
self.profile_path = profile_path
16+
self._profiles = self._load_profiles()
17+
self.server_connections: Dict[str, Dict[str, ServerConnection]] = {}
18+
19+
def _load_profiles(self) -> Dict[str, Profile]:
20+
if not os.path.exists(self.profile_path):
21+
return {}
22+
with open(self.profile_path, "r") as f:
23+
conf = json.load(f)
24+
return {name: Profile.model_validate(config) for name, config in conf.items()}
25+
26+
def _save_profiles(self) -> None:
27+
with open(self.profile_path, "w") as f:
28+
json.dump(self._profiles, f, indent=2)
29+
30+
def get_profile(self, name: str) -> Optional[Profile]:
31+
return self._profiles.get(name)
32+
33+
def validate_api_key(self, profile_name: str, api_key: str) -> bool:
34+
profile = self.get_profile(profile_name)
35+
if not profile:
36+
return False
37+
return api_key == profile.api_key
38+
39+
def create_profile(self, profile_name: str) -> None:
40+
if profile_name in self._profiles:
41+
return
42+
self._profiles[profile_name] = Profile(name=profile_name, servers=[])
43+
self._save_profiles()
44+
45+
def delete_profile(self, profile_name: str) -> None:
46+
if profile_name not in self._profiles:
47+
return
48+
del self._profiles[profile_name]
49+
self._save_profiles()
50+
51+
def list_profiles(self) -> list[Profile]:
52+
return list(self._profiles.values())
53+
54+
def rename_profile(self, old_name: str, new_name: str) -> None:
55+
if old_name not in self._profiles:
56+
return
57+
if new_name in self._profiles:
58+
return
59+
self._profiles[new_name] = self._profiles.pop(old_name)
60+
self._save_profiles()
61+
62+
def add_server(self, profile_name: str, server_config: ServerConfig) -> None:
63+
if profile_name not in self._profiles:
64+
return
65+
self._profiles[profile_name].servers.append(server_config)
66+
self._save_profiles()
67+
68+
def remove_server(self, profile_name: str, server_name: str) -> None:
69+
if profile_name not in self._profiles:
70+
return
71+
self._profiles[profile_name].servers = [
72+
server for server in self._profiles[profile_name].servers if server.name != server_name
73+
]
74+
self._save_profiles()
75+
76+
def update_server(self, profile_name: str, server_config: ServerConfig) -> None:
77+
if profile_name not in self._profiles:
78+
return
79+
for idx, server in enumerate(self._profiles[profile_name].servers):
80+
if server.name == server_config.name:
81+
self._profiles[profile_name].servers[idx] = server_config
82+
break
83+
self._save_profiles()
84+
85+
@override
86+
async def activate_profile(self, profile_name: str) -> Dict[str, ServerConnection]:
87+
if profile_name not in self._profiles:
88+
return {}
89+
servers = self._profiles[profile_name].servers
90+
self.server_connections[profile_name] = {
91+
server_config.name: ServerConnection(server_config) for server_config in servers
92+
}
93+
return self.server_connections[profile_name]
94+
95+
@override
96+
async def deactivate_profile(self, profile_name: str) -> None:
97+
if profile_name not in self._profiles:
98+
return
99+
if profile_name not in self.server_connections:
100+
return
101+
for _, client in self.server_connections[profile_name].items():
102+
await client.request_for_shutdown()
103+
self.server_connections.pop(profile_name)
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
from profile import Profile
2+
3+
from mcpm.core.schema import ServerConfig
4+
5+
6+
class AbstractProfileManager:
7+
def __init__(self):
8+
pass
9+
10+
def get_profile(self, name: str) -> Profile | None:
11+
raise NotImplementedError
12+
13+
def validate_api_key(self, profile_name: str, api_key: str) -> bool:
14+
raise NotImplementedError
15+
16+
def create_profile(self, profile_name: str) -> None:
17+
raise NotImplementedError
18+
19+
def delete_profile(self, profile_name: str) -> None:
20+
raise NotImplementedError
21+
22+
def list_profiles(self) -> list[Profile]:
23+
raise NotImplementedError
24+
25+
def rename_profile(self, old_name: str, new_name: str) -> None:
26+
raise NotImplementedError
27+
28+
def add_server(self, profile_name: str, server_config: ServerConfig) -> None:
29+
raise NotImplementedError
30+
31+
def remove_server(self, profile_name: str, server_name: str) -> None:
32+
raise NotImplementedError
33+
34+
def update_server(self, profile_name: str, server_config: ServerConfig) -> None:
35+
raise NotImplementedError
36+
37+
async def activate_profile(self, profile_name: str) -> None:
38+
raise NotImplementedError
39+
40+
async def deactivate_profile(self, profile_name: str) -> None:
41+
raise NotImplementedError

src/mcpm/core/schema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,5 +80,5 @@ def to_mcp_proxy_stdio(self) -> STDIOServerConfig:
8080

8181
class Profile(BaseModel):
8282
name: str
83-
api_key: Optional[str]
83+
api_key: Optional[str] = None
8484
servers: list[ServerConfig]

src/mcpm/schemas/server_config.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,6 @@
44

55
__all__ = ["ServerConfig", "SSEServerConfig", "STDIOServerConfig"]
66

7-
warnings.warn("mcpm.schemas.server_config is deprecated, use mcpm.core.schema instead", DeprecationWarning)
7+
warnings.warn(
8+
"mcpm.schemas.server_config is deprecated, use mcpm.core.schema instead", DeprecationWarning, stacklevel=2
9+
)

0 commit comments

Comments
 (0)