|
21 | 21 | from typing import TextIO
|
22 | 22 | from typing import Union
|
23 | 23 |
|
| 24 | +from pydantic import model_validator |
| 25 | +from typing_extensions import override |
| 26 | + |
24 | 27 | from ...agents.readonly_context import ReadonlyContext
|
25 | 28 | from ...auth.auth_credential import AuthCredential
|
26 | 29 | from ...auth.auth_schemes import AuthScheme
|
27 | 30 | from ..base_tool import BaseTool
|
28 | 31 | from ..base_toolset import BaseToolset
|
29 | 32 | from ..base_toolset import ToolPredicate
|
| 33 | +from ..tool_configs import BaseToolConfig |
| 34 | +from ..tool_configs import ToolArgsConfig |
30 | 35 | from .mcp_session_manager import MCPSessionManager
|
31 | 36 | from .mcp_session_manager import retry_on_closed_resource
|
32 | 37 | from .mcp_session_manager import SseConnectionParams
|
@@ -178,3 +183,67 @@ async def close(self) -> None:
|
178 | 183 | except Exception as e:
|
179 | 184 | # Log the error but don't re-raise to avoid blocking shutdown
|
180 | 185 | print(f"Warning: Error during MCPToolset cleanup: {e}", file=self._errlog)
|
| 186 | + |
| 187 | + @override |
| 188 | + @classmethod |
| 189 | + def from_config( |
| 190 | + cls: type[MCPToolset], config: ToolArgsConfig, config_abs_path: str |
| 191 | + ) -> MCPToolset: |
| 192 | + """Creates an MCPToolset from a configuration object.""" |
| 193 | + mcp_toolset_config = MCPToolsetConfig.model_validate(config.model_dump()) |
| 194 | + |
| 195 | + if mcp_toolset_config.stdio_server_params: |
| 196 | + connection_params = mcp_toolset_config.stdio_server_params |
| 197 | + elif mcp_toolset_config.stdio_connection_params: |
| 198 | + connection_params = mcp_toolset_config.stdio_connection_params |
| 199 | + elif mcp_toolset_config.sse_connection_params: |
| 200 | + connection_params = mcp_toolset_config.sse_connection_params |
| 201 | + elif mcp_toolset_config.streamable_http_connection_params: |
| 202 | + connection_params = mcp_toolset_config.streamable_http_connection_params |
| 203 | + else: |
| 204 | + raise ValueError("No connection params found in MCPToolsetConfig.") |
| 205 | + |
| 206 | + return cls( |
| 207 | + connection_params=connection_params, |
| 208 | + tool_filter=mcp_toolset_config.tool_filter, |
| 209 | + auth_scheme=mcp_toolset_config.auth_scheme, |
| 210 | + auth_credential=mcp_toolset_config.auth_credential, |
| 211 | + ) |
| 212 | + |
| 213 | + |
| 214 | +class MCPToolsetConfig(BaseToolConfig): |
| 215 | + """The config for MCPToolset.""" |
| 216 | + |
| 217 | + stdio_server_params: Optional[StdioServerParameters] = None |
| 218 | + |
| 219 | + stdio_connection_params: Optional[StdioConnectionParams] = None |
| 220 | + |
| 221 | + sse_connection_params: Optional[SseConnectionParams] = None |
| 222 | + |
| 223 | + streamable_http_connection_params: Optional[ |
| 224 | + StreamableHTTPConnectionParams |
| 225 | + ] = None |
| 226 | + |
| 227 | + tool_filter: Optional[List[str]] = None |
| 228 | + |
| 229 | + auth_scheme: Optional[AuthScheme] = None |
| 230 | + |
| 231 | + auth_credential: Optional[AuthCredential] = None |
| 232 | + |
| 233 | + @model_validator(mode="after") |
| 234 | + def _check_only_one_params_field(self): |
| 235 | + param_fields = [ |
| 236 | + self.stdio_server_params, |
| 237 | + self.stdio_connection_params, |
| 238 | + self.sse_connection_params, |
| 239 | + self.streamable_http_connection_params, |
| 240 | + ] |
| 241 | + populated_fields = [f for f in param_fields if f is not None] |
| 242 | + |
| 243 | + if len(populated_fields) != 1: |
| 244 | + raise ValueError( |
| 245 | + "Exactly one of stdio_server_params, stdio_connection_params," |
| 246 | + " sse_connection_params, streamable_http_connection_params must be" |
| 247 | + " set." |
| 248 | + ) |
| 249 | + return self |
0 commit comments