Skip to content

Commit 8b4aa2a

Browse files
authored
feat: add toolcall & pipelinecall for local UltraRAG integration (#143)
1 parent 9ad95b3 commit 8b4aa2a

File tree

3 files changed

+259
-137
lines changed

3 files changed

+259
-137
lines changed

script/api_usage_example.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Example for ToolCall usage with benchmark and retriever servers
2+
3+
from ultrarag.api import initialize, ToolCall
4+
5+
6+
initialize(["benchmark", "retriever"], server_root="servers")
7+
8+
benchmark_param_dict = {
9+
"key_map":{
10+
"gt_ls": "golden_answers",
11+
"q_ls": "question"
12+
},
13+
"limit": -1,
14+
"seed": 42,
15+
"name": "nq",
16+
"path": "data/sample_nq_10.jsonl",
17+
18+
}
19+
benchmark = ToolCall.benchmark.get_data(benchmark_param_dict)
20+
21+
query_list = benchmark['q_ls']
22+
23+
24+
retriever_init_param_dict = {
25+
"model_name_or_path": "Qwen/Qwen3-Embedding-0.6B",
26+
}
27+
28+
ToolCall.retriever.retriever_init(
29+
**retriever_init_param_dict
30+
)
31+
32+
result = ToolCall.retriever.retriever_search(
33+
query_list=query_list,
34+
top_k=5,
35+
)
36+
37+
retrieve_passages = result['ret_psg']
38+
39+
40+
# Example for PipelineCall usage with rag_deploy.yaml
41+
42+
from ultrarag.api import PipelineCall
43+
44+
result = PipelineCall(
45+
pipeline_file="examples/rag_deploy.yaml",
46+
parameter_file="examples/parameter/rag_deploy_parameter.yaml",
47+
)
48+
49+
final_step_result = result['final_result']
50+
all_steps_result = result['all_results']
51+
52+
53+

src/ultrarag/api.py

Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
import asyncio
2+
import os
3+
from types import SimpleNamespace
4+
from typing import List
5+
6+
import yaml
7+
from fastmcp import Client
8+
9+
from .mcp_logging import get_logger
10+
from . import client as _client_mod
11+
12+
_client: Client | None = None
13+
_servers: List[str] | None = None
14+
SERVER_ROOT = ""
15+
logger = None
16+
17+
18+
class _CallWrapper:
19+
"""Wraps a MCP tool so it can be called like a normal Python function."""
20+
21+
def __init__(self, client: Client, server: str, tool: str, multi: bool):
22+
self._client = client
23+
self._server = server
24+
self._tool = tool
25+
self._multi = multi
26+
27+
async def _ensure_client(self):
28+
global _client, logger
29+
if _client is None:
30+
raise RuntimeError(
31+
"[UltraRAG Error] ToolCall was used before `initialize()` was called."
32+
)
33+
try:
34+
_ = _client.session
35+
except RuntimeError:
36+
await _client.__aenter__()
37+
tools = await _client.list_tools()
38+
tool_name_lst = [
39+
tool.name
40+
for tool in tools
41+
if not tool.name.endswith("_build" if "_" in tool.name else "build")
42+
]
43+
logger.info(f"Available tools: {tool_name_lst}")
44+
45+
async def _async_call(self, *args, **kwargs):
46+
global _client, SERVER_ROOT
47+
await self._ensure_client()
48+
concated = f"{self._server}_{self._tool}" if self._multi else self._tool
49+
param_file = os.path.join(SERVER_ROOT, self._server, "parameter.yaml")
50+
if os.path.exists(param_file):
51+
with open(param_file, "r") as f:
52+
parameter = yaml.safe_load(f)
53+
else:
54+
parameter = {}
55+
56+
with open(os.path.join(SERVER_ROOT, self._server, "server.yaml"), "r") as f:
57+
try:
58+
input_param = yaml.safe_load(f)["tools"][self._tool]["input"]
59+
input_keys = list(input_param.keys())
60+
except:
61+
raise ValueError(
62+
f"[UltraRAG Error] Tool {self._tool} not found in server {self._server} configuration!"
63+
)
64+
65+
for k, v in list(input_param.items()):
66+
if isinstance(v, str) and v.startswith("$"):
67+
key = v[1:]
68+
if key not in parameter:
69+
continue
70+
input_param[k] = parameter[key]
71+
else:
72+
input_param[k] = None
73+
74+
if len(args) > len(input_param):
75+
raise ValueError(
76+
f"[UltraRAG Error] Expected at most {len(input_param)} positional args, got {len(args)}"
77+
)
78+
for pos, value in enumerate(args):
79+
key = input_keys[pos]
80+
input_param[key] = value
81+
82+
for k, v in kwargs.items():
83+
if k not in input_param:
84+
raise ValueError(f"[UltraRAG Error] Unexpected keyword arg: {k!r}")
85+
input_param[k] = v
86+
87+
missing = [k for k, v in input_param.items() if v is None]
88+
if missing:
89+
raise ValueError(f"[UltraRAG Error] Missing value for key(s): {missing}")
90+
result = await self._client.call_tool(concated, input_param)
91+
return result.data if result else None
92+
93+
def __call__(self, *args, **kwargs):
94+
# return asyncio.run(self._async_call(*args, **kwargs))
95+
loop = asyncio.get_event_loop_policy().get_event_loop()
96+
if loop.is_running():
97+
return loop.create_task(self._async_call(*args, **kwargs))
98+
else:
99+
return loop.run_until_complete(self._async_call(*args, **kwargs))
100+
101+
102+
class _ServerProxy(SimpleNamespace):
103+
"""
104+
Proxy for a specific server, e.g. `ToolCall.retriever`.
105+
106+
Accessing an attribute on this object (e.g. `.retriever_search`) returns a
107+
`_CallWrapper` bound to that (server, tool) pair.
108+
"""
109+
def __init__(self, client: Client, name: str, multi: bool):
110+
super().__init__()
111+
self._client = client
112+
self._name = name
113+
self._multi = multi
114+
115+
def __getattr__(self, tool_name: str):
116+
return _CallWrapper(self._client, self._name, tool_name, self._multi)
117+
118+
119+
class _Router(SimpleNamespace):
120+
"""
121+
Top-level router for ToolCall.
122+
123+
Example:
124+
ToolCall.retriever.retriever_search(...)
125+
ToolCall.benchmark.get_data(...)
126+
"""
127+
def __getattr__(self, server: str):
128+
global _client, _servers
129+
if server not in _servers:
130+
raise AttributeError(f"Server {server} has not been initialized!")
131+
return _ServerProxy(_client, server, len(_servers) > 1)
132+
133+
134+
def initialize(servers: list[str], server_root: str, log_level="info"):
135+
"""
136+
Initialize MCP servers so they can be accessed via ToolCall.
137+
"""
138+
global _client, _servers, SERVER_ROOT, logger
139+
logger = get_logger("Client", log_level)
140+
SERVER_ROOT = server_root
141+
mcp_cfg = {"mcpServers": {}}
142+
for server_name in servers:
143+
path = os.path.join(server_root, server_name, "src", f"{server_name}.py")
144+
if not os.path.exists(path):
145+
raise ValueError(f"Server path {path} does not exist!")
146+
mcp_cfg["mcpServers"][server_name] = {
147+
"command": "python",
148+
"args": [path],
149+
"env": os.environ.copy(),
150+
}
151+
152+
_client = Client(mcp_cfg)
153+
_servers = servers
154+
155+
156+
ToolCall = _Router()
157+
158+
159+
async def _pipeline_async(
160+
pipeline_file: str,
161+
parameter_file: str,
162+
log_level: str = "error",
163+
):
164+
"""
165+
Internal async helper that runs a full UltraRAG pipeline with
166+
an explicitly provided parameter file.
167+
"""
168+
_client_mod.logger = get_logger("Client", log_level)
169+
170+
return await _client_mod.run(pipeline_file, parameter_file, return_all=True)
171+
172+
173+
def PipelineCall(
174+
pipeline_file: str,
175+
parameter_file: str,
176+
log_level: str = "error",
177+
):
178+
"""
179+
Run a full UltraRAG pipeline from Python, similar to `ultrarag run`,
180+
but with an explicitly provided parameter file.
181+
"""
182+
loop = asyncio.get_event_loop_policy().get_event_loop()
183+
if loop.is_running():
184+
return loop.create_task(
185+
_pipeline_async(pipeline_file, parameter_file, log_level)
186+
)
187+
else:
188+
return loop.run_until_complete(
189+
_pipeline_async(pipeline_file, parameter_file, log_level)
190+
)

0 commit comments

Comments
 (0)