|
| 1 | +# coding=utf-8 |
| 2 | +import asyncio |
| 3 | +import json |
| 4 | +from typing import List |
| 5 | + |
| 6 | +from langchain_mcp_adapters.client import MultiServerMCPClient |
| 7 | + |
| 8 | +from application.flow.i_step_node import NodeResult |
| 9 | +from application.flow.step_node.mcp_node.i_mcp_node import IMcpNode |
| 10 | + |
| 11 | + |
| 12 | +class BaseMcpNode(IMcpNode): |
| 13 | + def save_context(self, details, workflow_manage): |
| 14 | + self.context['result'] = details.get('result') |
| 15 | + self.context['tool_params'] = details.get('tool_params') |
| 16 | + self.context['mcp_tool'] = details.get('mcp_tool') |
| 17 | + self.answer_text = details.get('result') |
| 18 | + |
| 19 | + def execute(self, mcp_servers, mcp_server, mcp_tool, tool_params, **kwargs) -> NodeResult: |
| 20 | + servers = json.loads(mcp_servers) |
| 21 | + params = self.handle_variables(tool_params) |
| 22 | + |
| 23 | + async def call_tool(s, session, t, a): |
| 24 | + async with MultiServerMCPClient(s) as client: |
| 25 | + s = await client.sessions[session].call_tool(t, a) |
| 26 | + return s |
| 27 | + |
| 28 | + res = asyncio.run(call_tool(servers, mcp_server, mcp_tool, params)) |
| 29 | + return NodeResult({'result': [content.text for content in res.content], 'tool_params': params, 'mcp_tool': mcp_tool}, {}) |
| 30 | + |
| 31 | + def handle_variables(self, tool_params): |
| 32 | + # 处理参数中的变量 |
| 33 | + for k, v in tool_params.items(): |
| 34 | + if type(v) == str: |
| 35 | + tool_params[k] = self.workflow_manage.generate_prompt(tool_params[k]) |
| 36 | + if type(v) == dict: |
| 37 | + self.handle_variables(v) |
| 38 | + return tool_params |
| 39 | + |
| 40 | + def get_reference_content(self, fields: List[str]): |
| 41 | + return str(self.workflow_manage.get_reference_field( |
| 42 | + fields[0], |
| 43 | + fields[1:])) |
| 44 | + |
| 45 | + def get_details(self, index: int, **kwargs): |
| 46 | + return { |
| 47 | + 'name': self.node.properties.get('stepName'), |
| 48 | + "index": index, |
| 49 | + 'run_time': self.context.get('run_time'), |
| 50 | + 'status': self.status, |
| 51 | + 'err_message': self.err_message, |
| 52 | + 'type': self.node.type, |
| 53 | + 'mcp_tool': self.context.get('mcp_tool'), |
| 54 | + 'tool_params': self.context.get('tool_params'), |
| 55 | + 'result': self.context.get('result'), |
| 56 | + } |
0 commit comments