Skip to content

Commit 27e2fc4

Browse files
committed
update cli with typer
1 parent 169c2d9 commit 27e2fc4

File tree

3 files changed

+105
-365
lines changed

3 files changed

+105
-365
lines changed

pyproject.toml

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ description = "An open source proxy for MCP servers"
55
readme = "README.md"
66
requires-python = ">=3.12"
77
dependencies = [
8-
"logfire>=2.6.2",
9-
"mcp>=1.1.0",
8+
"mcp>=1.2.0rc1",
9+
"typer>=0.15.1",
1010
]
1111

1212
[tool.hatch.version]
@@ -31,9 +31,6 @@ addopts = [
3131
asyncio_mode = "auto"
3232
asyncio_default_fixture_loop_scope = "function"
3333

34-
[tool.uv.sources]
35-
mcp = { git = "https://github.com/modelcontextprotocol/python-sdk.git", rev = "main" }
36-
3734
[build-system]
3835
requires = ["hatchling"]
3936
build-backend = "hatchling.build"

src/omproxy/cli.py

Lines changed: 70 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1,90 +1,83 @@
11
#!/usr/bin/env python3
22

3-
import argparse
4-
import logging
53
import os
6-
import uuid
7-
from contextvars import ContextVar
8-
from pathlib import Path
4+
from typing import Optional
5+
import typer
6+
from typing_extensions import Annotated
97

10-
import anyio
11-
import logfire
128
from mcp.client.stdio import StdioServerParameters
13-
149
from omproxy import __version__
15-
from omproxy.highlevel_proxy import run_stdio_client
16-
17-
# Create a global context variable for instance_id
18-
instance_id_var = ContextVar("instance_id", default=None)
19-
20-
21-
def get_or_create_instance_id() -> str:
22-
"""Get or create a persistent UUID for this proxy instance."""
23-
id_file = Path.home() / ".omproxy" / "instance_id"
24-
id_file.parent.mkdir(parents=True, exist_ok=True)
25-
26-
if id_file.exists():
27-
return id_file.read_text().strip()
28-
29-
instance_id = str(uuid.uuid4())
30-
id_file.write_text(instance_id)
31-
return instance_id
32-
33-
34-
def main():
35-
parser = argparse.ArgumentParser(
36-
description="Bidirectional proxy for subprocess communication"
37-
)
38-
parser.add_argument(
39-
"--name",
40-
"-n",
41-
type=str,
42-
help="Name of the service",
43-
)
44-
parser.add_argument(
45-
"--version", action="version", version=__version__, help="Show version and exit"
46-
)
47-
parser.add_argument(
48-
"-v", "--verbose", action="store_true", help="Enable debug logging"
49-
)
50-
parser.add_argument("command", help="Command to run with optional arguments")
51-
parser.add_argument(
52-
"args", nargs=argparse.REMAINDER, help="Arguments to pass to the command"
53-
)
54-
args = parser.parse_args()
55-
56-
# TODO: (use auth see https://github.com/pydantic/logfire/issues/651#issuecomment-2522714987)
57-
os.environ["LOGFIRE_TOKEN"] = "BHVQS0FylRTlf3j50WHNzh8S6ypPCJ308cjcyrdNp3Jc"
58-
os.environ["LOGFIRE_PROJECT_NAME"] = "iod-mcp"
59-
os.environ["LOGFIRE_PROJECT_URL"] = "https://logfire.pydantic.dev/grll/iod-mcp"
60-
os.environ["LOGFIRE_API_URL"] = "https://logfire-api.pydantic.dev"
61-
62-
instance_id = get_or_create_instance_id()
63-
instance_id_var.set(instance_id)
64-
65-
# Configure logging
66-
logfire.configure(
67-
service_name=f"omproxy[{args.name}]",
68-
service_version=__version__,
69-
console=False,
10+
from omproxy.proxy import SseProxy, StdioProxy
11+
12+
app = typer.Typer(no_args_is_help=True)
13+
14+
15+
def version_callback(value: bool):
16+
if value:
17+
typer.echo(f"omproxy version: {__version__}")
18+
raise typer.Exit()
19+
20+
21+
@app.callback()
22+
def callback(
23+
version: Annotated[
24+
bool,
25+
typer.Option(
26+
"--version", callback=version_callback, help="Show version and exit"
27+
),
28+
] = False,
29+
):
30+
"""Bidirectional proxy for subprocess communication."""
31+
pass
32+
33+
34+
@app.command()
35+
def sse(
36+
url: Annotated[str, typer.Option(help="SSE server URL")],
37+
headers: Annotated[
38+
Optional[str], typer.Option(help="SSE headers as key1=value1,key2=value2")
39+
] = None,
40+
timeout: Annotated[float, typer.Option(help="SSE connection timeout")] = 5.0,
41+
sse_read_timeout: Annotated[float, typer.Option(help="SSE read timeout")] = 300.0,
42+
):
43+
"""Use SSE proxy protocol"""
44+
# Parse headers if provided
45+
headers_dict = {}
46+
if headers:
47+
try:
48+
headers_dict = dict(h.split("=") for h in headers.split(","))
49+
except ValueError:
50+
typer.echo(
51+
"Error: Invalid headers format. Use key1=value1,key2=value2", err=True
52+
)
53+
raise typer.Exit(1)
54+
55+
proxy = SseProxy()
56+
proxy.run(
57+
url=url,
58+
headers=headers_dict,
59+
timeout=timeout,
60+
sse_read_timeout=sse_read_timeout,
7061
)
71-
logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
7262

73-
logfire.info(
74-
"starting_proxy", command=args.command, args=args.args, instance_id=instance_id
75-
)
7663

77-
async def run_proxy():
78-
await run_stdio_client(
79-
StdioServerParameters(
80-
command=args.command,
81-
args=args.args,
82-
env=os.environ,
83-
)
64+
@app.command()
65+
def stdio(
66+
command: Annotated[str, typer.Argument(help="Command to run")],
67+
args: Annotated[
68+
Optional[list[str]], typer.Argument(help="Arguments for the command")
69+
] = None,
70+
):
71+
"""Use stdio proxy protocol"""
72+
proxy = StdioProxy()
73+
proxy.run(
74+
StdioServerParameters(
75+
command=command,
76+
args=args or [],
77+
env=os.environ,
8478
)
85-
86-
anyio.run(run_proxy)
79+
)
8780

8881

8982
if __name__ == "__main__":
90-
main()
83+
app()

0 commit comments

Comments
 (0)