60
60
from sse_starlette .sse import EventSourceResponse
61
61
import uvicorn
62
62
63
+ # Conditional imports
64
+ try :
65
+ # Third-Party
66
+ from fastapi .middleware .cors import CORSMiddleware
67
+ except ImportError :
68
+ CORSMiddleware = None
69
+
70
+ try :
71
+ # Third-Party
72
+ import httpx
73
+ except ImportError :
74
+ httpx = None
75
+
63
76
LOGGER = logging .getLogger ("mcpgateway.translate" )
64
77
KEEP_ALIVE_INTERVAL = 30 # seconds - matches the reference implementation
65
78
__all__ = ["main" ] # for console-script entry-point
@@ -75,6 +88,14 @@ def __init__(self) -> None:
75
88
self ._subscribers : List [asyncio .Queue [str ]] = []
76
89
77
90
async def publish (self , data : str ) -> None :
91
+ """Publish data to all subscribers.
92
+
93
+ Args:
94
+ data: The data string to publish to all subscribers.
95
+
96
+ Returns:
97
+ None
98
+ """
78
99
dead : List [asyncio .Queue [str ]] = []
79
100
for q in self ._subscribers :
80
101
try :
@@ -86,11 +107,24 @@ async def publish(self, data: str) -> None:
86
107
self ._subscribers .remove (q )
87
108
88
109
def subscribe (self ) -> "asyncio.Queue[str]" :
110
+ """Subscribe to published data.
111
+
112
+ Returns:
113
+ asyncio.Queue[str]: A queue that will receive published data.
114
+ """
89
115
q : asyncio .Queue [str ] = asyncio .Queue (maxsize = 1024 )
90
116
self ._subscribers .append (q )
91
117
return q
92
118
93
119
def unsubscribe (self , q : "asyncio.Queue[str]" ) -> None :
120
+ """Unsubscribe from published data.
121
+
122
+ Args:
123
+ q: The queue to unsubscribe from published data.
124
+
125
+ Returns:
126
+ None
127
+ """
94
128
with suppress (ValueError ):
95
129
self ._subscribers .remove (q )
96
130
@@ -109,6 +143,16 @@ def __init__(self, cmd: str, pubsub: _PubSub) -> None:
109
143
self ._pump_task : Optional [asyncio .Task [None ]] = None
110
144
111
145
async def start (self ) -> None :
146
+ """Start the stdio subprocess.
147
+
148
+ Creates the subprocess and starts the stdout pump task.
149
+
150
+ Returns:
151
+ None
152
+
153
+ Raises:
154
+ OSError: If the subprocess cannot be started.
155
+ """
112
156
LOGGER .info ("Starting stdio subprocess: %s" , self ._cmd )
113
157
self ._proc = await asyncio .create_subprocess_exec (
114
158
* shlex .split (self ._cmd ),
@@ -121,6 +165,13 @@ async def start(self) -> None:
121
165
self ._pump_task = asyncio .create_task (self ._pump_stdout ())
122
166
123
167
async def stop (self ) -> None :
168
+ """Stop the stdio subprocess.
169
+
170
+ Terminates the subprocess and cancels the pump task.
171
+
172
+ Returns:
173
+ None
174
+ """
124
175
if self ._proc is None :
125
176
return
126
177
LOGGER .info ("Stopping subprocess (pid=%s)" , self ._proc .pid )
@@ -131,13 +182,35 @@ async def stop(self) -> None:
131
182
self ._pump_task .cancel ()
132
183
133
184
async def send (self , raw : str ) -> None :
185
+ """Send data to the subprocess stdin.
186
+
187
+ Args:
188
+ raw: The raw data string to send to the subprocess.
189
+
190
+ Returns:
191
+ None
192
+
193
+ Raises:
194
+ RuntimeError: If the stdio endpoint is not started.
195
+ """
134
196
if not self ._stdin :
135
197
raise RuntimeError ("stdio endpoint not started" )
136
198
LOGGER .debug ("→ stdio: %s" , raw .strip ())
137
199
self ._stdin .write (raw .encode ())
138
200
await self ._stdin .drain ()
139
201
140
202
async def _pump_stdout (self ) -> None :
203
+ """Pump stdout from subprocess to pubsub.
204
+
205
+ Continuously reads lines from the subprocess stdout and publishes them
206
+ to the pubsub system.
207
+
208
+ Returns:
209
+ None
210
+
211
+ Raises:
212
+ Exception: If the stdout pump encounters an error.
213
+ """
141
214
assert self ._proc and self ._proc .stdout
142
215
reader = self ._proc .stdout
143
216
try :
@@ -168,13 +241,23 @@ def _build_fastapi(
168
241
message_path : str = "/message" ,
169
242
cors_origins : Optional [List [str ]] = None ,
170
243
) -> FastAPI :
244
+ """Build FastAPI application with SSE and message endpoints.
245
+
246
+ Args:
247
+ pubsub: The publish/subscribe system for message routing.
248
+ stdio: The stdio endpoint for subprocess communication.
249
+ keep_alive: Interval in seconds for keepalive messages. Defaults to KEEP_ALIVE_INTERVAL.
250
+ sse_path: Path for the SSE endpoint. Defaults to "/sse".
251
+ message_path: Path for the message endpoint. Defaults to "/message".
252
+ cors_origins: Optional list of CORS allowed origins.
253
+
254
+ Returns:
255
+ FastAPI: The configured FastAPI application.
256
+ """
171
257
app = FastAPI ()
172
258
173
259
# Add CORS middleware if origins specified
174
- if cors_origins :
175
- # Third-Party
176
- from fastapi .middleware .cors import CORSMiddleware
177
-
260
+ if cors_origins and CORSMiddleware :
178
261
app .add_middleware (
179
262
CORSMiddleware ,
180
263
allow_origins = cors_origins ,
@@ -254,6 +337,7 @@ async def post_message(raw: Request, session_id: str | None = None) -> Response:
254
337
Response: ``202 Accepted`` if the payload is forwarded successfully,
255
338
or ``400 Bad Request`` when the body is not valid JSON.
256
339
"""
340
+ _ = session_id # Unused but required for API compatibility
257
341
payload = await raw .body ()
258
342
try :
259
343
json .loads (payload ) # validate
@@ -268,6 +352,11 @@ async def post_message(raw: Request, session_id: str | None = None) -> Response:
268
352
# ----- Liveness ---------------------------------------------------------#
269
353
@app .get ("/healthz" )
270
354
async def health () -> Response : # noqa: D401
355
+ """Health check endpoint.
356
+
357
+ Returns:
358
+ Response: A plain text response with "ok" status.
359
+ """
271
360
return PlainTextResponse ("ok" )
272
361
273
362
return app
@@ -279,6 +368,17 @@ async def health() -> Response: # noqa: D401
279
368
280
369
281
370
def _parse_args (argv : Sequence [str ]) -> argparse .Namespace :
371
+ """Parse command line arguments.
372
+
373
+ Args:
374
+ argv: Sequence of command line arguments.
375
+
376
+ Returns:
377
+ argparse.Namespace: Parsed command line arguments.
378
+
379
+ Raises:
380
+ NotImplementedError: If streamableHttp option is specified.
381
+ """
282
382
p = argparse .ArgumentParser (
283
383
prog = "mcpgateway.translate" ,
284
384
description = "Bridges stdio JSON-RPC to SSE or SSE to stdio." ,
@@ -312,6 +412,17 @@ def _parse_args(argv: Sequence[str]) -> argparse.Namespace:
312
412
313
413
314
414
async def _run_stdio_to_sse (cmd : str , port : int , log_level : str = "info" , cors : Optional [List [str ]] = None ) -> None :
415
+ """Run stdio to SSE bridge.
416
+
417
+ Args:
418
+ cmd: The command to run as a stdio subprocess.
419
+ port: The port to bind the HTTP server to.
420
+ log_level: The logging level to use. Defaults to "info".
421
+ cors: Optional list of CORS allowed origins.
422
+
423
+ Returns:
424
+ None
425
+ """
315
426
pubsub = _PubSub ()
316
427
stdio = StdIOEndpoint (cmd , pubsub )
317
428
await stdio .start ()
@@ -346,9 +457,21 @@ async def _shutdown() -> None:
346
457
await _shutdown () # final cleanup
347
458
348
459
349
- async def _run_sse_to_stdio (url : str , oauth2_bearer : Optional [str ], log_level : str = "info" ) -> None :
350
- # Third-Party
351
- import httpx
460
+ async def _run_sse_to_stdio (url : str , oauth2_bearer : Optional [str ]) -> None :
461
+ """Run SSE to stdio bridge.
462
+
463
+ Args:
464
+ url: The SSE endpoint URL to connect to.
465
+ oauth2_bearer: Optional OAuth2 bearer token for authentication.
466
+
467
+ Returns:
468
+ None
469
+
470
+ Raises:
471
+ ImportError: If httpx package is not available.
472
+ """
473
+ if not httpx :
474
+ raise ImportError ("httpx package is required for SSE to stdio bridging" )
352
475
353
476
headers = {}
354
477
if oauth2_bearer :
@@ -384,24 +507,56 @@ async def pump_sse_to_stdio():
384
507
385
508
386
509
def start_stdio (cmd , port , log_level , cors ):
510
+ """Start stdio bridge.
511
+
512
+ Args:
513
+ cmd: The command to run as a stdio subprocess.
514
+ port: The port to bind the HTTP server to.
515
+ log_level: The logging level to use.
516
+ cors: Optional list of CORS allowed origins.
517
+
518
+ Returns:
519
+ None
520
+ """
387
521
return asyncio .run (_run_stdio_to_sse (cmd , port , log_level , cors ))
388
522
389
523
390
- def start_sse (url , bearer , log_level ):
391
- return asyncio .run (_run_sse_to_stdio (url , bearer , log_level ))
524
+ def start_sse (url , bearer ):
525
+ """Start SSE bridge.
526
+
527
+ Args:
528
+ url: The SSE endpoint URL to connect to.
529
+ bearer: Optional OAuth2 bearer token for authentication.
530
+
531
+ Returns:
532
+ None
533
+ """
534
+ return asyncio .run (_run_sse_to_stdio (url , bearer ))
535
+
536
+
537
+ def main (argv : Optional [Sequence [str ]] | None = None ) -> None :
538
+ """Entry point for the translate module.
539
+
540
+ Args:
541
+ argv: Optional sequence of command line arguments. If None, uses sys.argv[1:].
392
542
543
+ Returns:
544
+ None
393
545
394
- def main (argv : Optional [Sequence [str ]] | None = None ) -> None : # entry-point
546
+ Raises:
547
+ NotImplementedError: If an unsupported option is specified.
548
+ KeyboardInterrupt: If the user interrupts the process.
549
+ """
395
550
args = _parse_args (argv or sys .argv [1 :])
396
551
logging .basicConfig (
397
552
level = getattr (logging , args .logLevel .upper (), logging .INFO ),
398
553
format = "%(asctime)s [%(levelname)s] %(name)s: %(message)s" ,
399
554
)
400
555
try :
401
556
if args .stdio :
402
- return start_stdio (args .stdio , args .port , args .logLevel , args .cors )
557
+ start_stdio (args .stdio , args .port , args .logLevel , args .cors )
403
558
elif args .sse :
404
- return start_sse (args .sse , args .oauth2Bearer , args . logLevel )
559
+ start_sse (args .sse , args .oauth2Bearer )
405
560
except KeyboardInterrupt :
406
561
print ("" ) # restore shell prompt
407
562
sys .exit (0 )
0 commit comments