Skip to content

Commit 812c1f8

Browse files
committed
feat: add JSON config parsing with input token replacement
1 parent 6e23a6a commit 812c1f8

File tree

1 file changed

+51
-20
lines changed
  • src/datapilot/core/mcp_utils

1 file changed

+51
-20
lines changed

src/datapilot/core/mcp_utils/mcp.py

Lines changed: 51 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -21,32 +21,65 @@ def create_mcp_proxy():
2121
content = click.edit()
2222
if content is None:
2323
click.echo("No input provided.")
24+
return
2425

25-
output = asyncio.run(list_tools())
26-
click.echo(json.dumps(output, indent=2))
26+
try:
27+
config = json.loads(content)
28+
except json.JSONDecodeError:
29+
click.echo("Invalid JSON content.")
30+
return
2731

28-
async def list_tools(command: str, args: list[str], env: dict[str, str]) -> str:
29-
command = shutil.which(command)
32+
inputs = {}
33+
mcp_config = config.get("mcp", {})
34+
35+
# Process inputs first
36+
for input_def in mcp_config.get("inputs", []):
37+
input_id = input_def["id"]
38+
inputs[input_id] = click.prompt(
39+
input_def.get("description", input_id),
40+
hide_input=input_def.get("password", False)
41+
)
42+
43+
# Process servers
44+
servers = mcp_config.get("servers", {})
45+
for server_name, server_config in servers.items():
46+
# Replace input tokens in args
47+
processed_args = [
48+
inputs.get(arg[8:-1], arg) if isinstance(arg, str) and arg.startswith("${input:") else arg
49+
for arg in server_config.get("args", [])
50+
]
51+
52+
# Replace input tokens in environment variables
53+
processed_env = {
54+
k: inputs.get(v[8:-1], v) if isinstance(v, str) and v.startswith("${input:") else v
55+
for k, v in server_config.get("env", {}).items()
56+
}
57+
58+
# Execute with processed parameters
59+
output = asyncio.run(list_tools(
60+
command=server_config["command"],
61+
args=processed_args,
62+
env=processed_env
63+
))
64+
click.echo(f"\nServer: {server_name}")
65+
click.echo(json.dumps(output, indent=2))
66+
67+
async def list_tools(command: str, args: list[str], env: dict[str, str]):
68+
command_path = shutil.which(command)
69+
if not command_path:
70+
raise click.UsageError(f"Command not found: {command}")
3071

31-
# Create server parameters for stdio connection
3272
server_params = StdioServerParameters(
33-
command=command, # Executable
34-
args=args, # Optional command line arguments
35-
env=None, # Optional environment variables
73+
command=command_path,
74+
args=args,
75+
env=env, # Now using processed env
3676
)
37-
77+
3878
async with stdio_client(server_params) as (read, write):
39-
async with ClientSession(
40-
read, write
41-
) as session:
42-
# Initialize the connection
79+
async with ClientSession(read, write) as session:
4380
await session.initialize()
44-
45-
# List available tools
4681
tools = await session.list_tools()
47-
48-
# print as json
49-
tools_list = [
82+
return [
5083
{
5184
"name": tool.name,
5285
"description": tool.description,
@@ -55,5 +88,3 @@ async def list_tools(command: str, args: list[str], env: dict[str, str]) -> str:
5588
for tool in tools.tools
5689
]
5790

58-
return tools_list
59-

0 commit comments

Comments
 (0)