Skip to content

Commit d9ce2e6

Browse files
wuliang229copybara-github
authored andcommitted
feat(config): implement config and from_config for MCPToolset
The connection_params argument in the constructor is split into four arguments in the config class because some of them have identical fields. In order to identify which is which, a separate name is more convenient. PiperOrigin-RevId: 791965995
1 parent dc193f7 commit d9ce2e6

File tree

2 files changed

+85
-1
lines changed

2 files changed

+85
-1
lines changed

src/google/adk/tools/__init__.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
14+
import logging
15+
import sys
1516

1617
from ..auth.auth_tool import AuthToolArguments
1718
from .agent_tool import AgentTool
@@ -52,3 +53,17 @@
5253
'ToolContext',
5354
'transfer_to_agent',
5455
]
56+
57+
58+
if sys.version_info < (3, 10):
59+
logger = logging.getLogger('google_adk.' + __name__)
60+
logger.warning(
61+
'MCP requires Python 3.10 or above. Please upgrade your Python'
62+
' version in order to use it.'
63+
)
64+
else:
65+
from .mcp_tool.mcp_toolset import MCPToolset
66+
67+
__all__.extend([
68+
'MCPToolset',
69+
])

src/google/adk/tools/mcp_tool/mcp_toolset.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,17 @@
2121
from typing import TextIO
2222
from typing import Union
2323

24+
from pydantic import model_validator
25+
from typing_extensions import override
26+
2427
from ...agents.readonly_context import ReadonlyContext
2528
from ...auth.auth_credential import AuthCredential
2629
from ...auth.auth_schemes import AuthScheme
2730
from ..base_tool import BaseTool
2831
from ..base_toolset import BaseToolset
2932
from ..base_toolset import ToolPredicate
33+
from ..tool_configs import BaseToolConfig
34+
from ..tool_configs import ToolArgsConfig
3035
from .mcp_session_manager import MCPSessionManager
3136
from .mcp_session_manager import retry_on_closed_resource
3237
from .mcp_session_manager import SseConnectionParams
@@ -178,3 +183,67 @@ async def close(self) -> None:
178183
except Exception as e:
179184
# Log the error but don't re-raise to avoid blocking shutdown
180185
print(f"Warning: Error during MCPToolset cleanup: {e}", file=self._errlog)
186+
187+
@override
188+
@classmethod
189+
def from_config(
190+
cls: type[MCPToolset], config: ToolArgsConfig, config_abs_path: str
191+
) -> MCPToolset:
192+
"""Creates an MCPToolset from a configuration object."""
193+
mcp_toolset_config = MCPToolsetConfig.model_validate(config.model_dump())
194+
195+
if mcp_toolset_config.stdio_server_params:
196+
connection_params = mcp_toolset_config.stdio_server_params
197+
elif mcp_toolset_config.stdio_connection_params:
198+
connection_params = mcp_toolset_config.stdio_connection_params
199+
elif mcp_toolset_config.sse_connection_params:
200+
connection_params = mcp_toolset_config.sse_connection_params
201+
elif mcp_toolset_config.streamable_http_connection_params:
202+
connection_params = mcp_toolset_config.streamable_http_connection_params
203+
else:
204+
raise ValueError("No connection params found in MCPToolsetConfig.")
205+
206+
return cls(
207+
connection_params=connection_params,
208+
tool_filter=mcp_toolset_config.tool_filter,
209+
auth_scheme=mcp_toolset_config.auth_scheme,
210+
auth_credential=mcp_toolset_config.auth_credential,
211+
)
212+
213+
214+
class MCPToolsetConfig(BaseToolConfig):
215+
"""The config for MCPToolset."""
216+
217+
stdio_server_params: Optional[StdioServerParameters] = None
218+
219+
stdio_connection_params: Optional[StdioConnectionParams] = None
220+
221+
sse_connection_params: Optional[SseConnectionParams] = None
222+
223+
streamable_http_connection_params: Optional[
224+
StreamableHTTPConnectionParams
225+
] = None
226+
227+
tool_filter: Optional[List[str]] = None
228+
229+
auth_scheme: Optional[AuthScheme] = None
230+
231+
auth_credential: Optional[AuthCredential] = None
232+
233+
@model_validator(mode="after")
234+
def _check_only_one_params_field(self):
235+
param_fields = [
236+
self.stdio_server_params,
237+
self.stdio_connection_params,
238+
self.sse_connection_params,
239+
self.streamable_http_connection_params,
240+
]
241+
populated_fields = [f for f in param_fields if f is not None]
242+
243+
if len(populated_fields) != 1:
244+
raise ValueError(
245+
"Exactly one of stdio_server_params, stdio_connection_params,"
246+
" sse_connection_params, streamable_http_connection_params must be"
247+
" set."
248+
)
249+
return self

0 commit comments

Comments
 (0)