1+ import asyncio
2+ import os
3+ from types import SimpleNamespace
4+ from typing import List
5+
6+ import yaml
7+ from fastmcp import Client
8+
9+ from .mcp_logging import get_logger
10+ from . import client as _client_mod
11+
12+ _client : Client | None = None
13+ _servers : List [str ] | None = None
14+ SERVER_ROOT = ""
15+ logger = None
16+
17+
18+ class _CallWrapper :
19+ """Wraps a MCP tool so it can be called like a normal Python function."""
20+
21+ def __init__ (self , client : Client , server : str , tool : str , multi : bool ):
22+ self ._client = client
23+ self ._server = server
24+ self ._tool = tool
25+ self ._multi = multi
26+
27+ async def _ensure_client (self ):
28+ global _client , logger
29+ if _client is None :
30+ raise RuntimeError (
31+ "[UltraRAG Error] ToolCall was used before `initialize()` was called."
32+ )
33+ try :
34+ _ = _client .session
35+ except RuntimeError :
36+ await _client .__aenter__ ()
37+ tools = await _client .list_tools ()
38+ tool_name_lst = [
39+ tool .name
40+ for tool in tools
41+ if not tool .name .endswith ("_build" if "_" in tool .name else "build" )
42+ ]
43+ logger .info (f"Available tools: { tool_name_lst } " )
44+
45+ async def _async_call (self , * args , ** kwargs ):
46+ global _client , SERVER_ROOT
47+ await self ._ensure_client ()
48+ concated = f"{ self ._server } _{ self ._tool } " if self ._multi else self ._tool
49+ param_file = os .path .join (SERVER_ROOT , self ._server , "parameter.yaml" )
50+ if os .path .exists (param_file ):
51+ with open (param_file , "r" ) as f :
52+ parameter = yaml .safe_load (f )
53+ else :
54+ parameter = {}
55+
56+ with open (os .path .join (SERVER_ROOT , self ._server , "server.yaml" ), "r" ) as f :
57+ try :
58+ input_param = yaml .safe_load (f )["tools" ][self ._tool ]["input" ]
59+ input_keys = list (input_param .keys ())
60+ except :
61+ raise ValueError (
62+ f"[UltraRAG Error] Tool { self ._tool } not found in server { self ._server } configuration!"
63+ )
64+
65+ for k , v in list (input_param .items ()):
66+ if isinstance (v , str ) and v .startswith ("$" ):
67+ key = v [1 :]
68+ if key not in parameter :
69+ continue
70+ input_param [k ] = parameter [key ]
71+ else :
72+ input_param [k ] = None
73+
74+ if len (args ) > len (input_param ):
75+ raise ValueError (
76+ f"[UltraRAG Error] Expected at most { len (input_param )} positional args, got { len (args )} "
77+ )
78+ for pos , value in enumerate (args ):
79+ key = input_keys [pos ]
80+ input_param [key ] = value
81+
82+ for k , v in kwargs .items ():
83+ if k not in input_param :
84+ raise ValueError (f"[UltraRAG Error] Unexpected keyword arg: { k !r} " )
85+ input_param [k ] = v
86+
87+ missing = [k for k , v in input_param .items () if v is None ]
88+ if missing :
89+ raise ValueError (f"[UltraRAG Error] Missing value for key(s): { missing } " )
90+ result = await self ._client .call_tool (concated , input_param )
91+ return result .data if result else None
92+
93+ def __call__ (self , * args , ** kwargs ):
94+ # return asyncio.run(self._async_call(*args, **kwargs))
95+ loop = asyncio .get_event_loop_policy ().get_event_loop ()
96+ if loop .is_running ():
97+ return loop .create_task (self ._async_call (* args , ** kwargs ))
98+ else :
99+ return loop .run_until_complete (self ._async_call (* args , ** kwargs ))
100+
101+
102+ class _ServerProxy (SimpleNamespace ):
103+ """
104+ Proxy for a specific server, e.g. `ToolCall.retriever`.
105+
106+ Accessing an attribute on this object (e.g. `.retriever_search`) returns a
107+ `_CallWrapper` bound to that (server, tool) pair.
108+ """
109+ def __init__ (self , client : Client , name : str , multi : bool ):
110+ super ().__init__ ()
111+ self ._client = client
112+ self ._name = name
113+ self ._multi = multi
114+
115+ def __getattr__ (self , tool_name : str ):
116+ return _CallWrapper (self ._client , self ._name , tool_name , self ._multi )
117+
118+
119+ class _Router (SimpleNamespace ):
120+ """
121+ Top-level router for ToolCall.
122+
123+ Example:
124+ ToolCall.retriever.retriever_search(...)
125+ ToolCall.benchmark.get_data(...)
126+ """
127+ def __getattr__ (self , server : str ):
128+ global _client , _servers
129+ if server not in _servers :
130+ raise AttributeError (f"Server { server } has not been initialized!" )
131+ return _ServerProxy (_client , server , len (_servers ) > 1 )
132+
133+
134+ def initialize (servers : list [str ], server_root : str , log_level = "info" ):
135+ """
136+ Initialize MCP servers so they can be accessed via ToolCall.
137+ """
138+ global _client , _servers , SERVER_ROOT , logger
139+ logger = get_logger ("Client" , log_level )
140+ SERVER_ROOT = server_root
141+ mcp_cfg = {"mcpServers" : {}}
142+ for server_name in servers :
143+ path = os .path .join (server_root , server_name , "src" , f"{ server_name } .py" )
144+ if not os .path .exists (path ):
145+ raise ValueError (f"Server path { path } does not exist!" )
146+ mcp_cfg ["mcpServers" ][server_name ] = {
147+ "command" : "python" ,
148+ "args" : [path ],
149+ "env" : os .environ .copy (),
150+ }
151+
152+ _client = Client (mcp_cfg )
153+ _servers = servers
154+
155+
156+ ToolCall = _Router ()
157+
158+
159+ async def _pipeline_async (
160+ pipeline_file : str ,
161+ parameter_file : str ,
162+ log_level : str = "error" ,
163+ ):
164+ """
165+ Internal async helper that runs a full UltraRAG pipeline with
166+ an explicitly provided parameter file.
167+ """
168+ _client_mod .logger = get_logger ("Client" , log_level )
169+
170+ return await _client_mod .run (pipeline_file , parameter_file , return_all = True )
171+
172+
173+ def PipelineCall (
174+ pipeline_file : str ,
175+ parameter_file : str ,
176+ log_level : str = "error" ,
177+ ):
178+ """
179+ Run a full UltraRAG pipeline from Python, similar to `ultrarag run`,
180+ but with an explicitly provided parameter file.
181+ """
182+ loop = asyncio .get_event_loop_policy ().get_event_loop ()
183+ if loop .is_running ():
184+ return loop .create_task (
185+ _pipeline_async (pipeline_file , parameter_file , log_level )
186+ )
187+ else :
188+ return loop .run_until_complete (
189+ _pipeline_async (pipeline_file , parameter_file , log_level )
190+ )
0 commit comments