forked from cuga-project/cuga-agent
-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathlifecycle.py
More file actions
119 lines (103 loc) · 5.08 KB
/
lifecycle.py
File metadata and controls
119 lines (103 loc) · 5.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from __future__ import annotations
# REVIEW-FIX: Circuit breaker accounting, allowlist pre-flight, and consistent error handling.
import asyncio
import time
from collections import defaultdict, deque
from typing import Dict, Optional, Tuple
from cuga.mcp.errors import CallTimeout, StartupError, ToolUnavailable
from cuga.mcp.interfaces import ToolRequest, ToolResponse, ToolSpec
from cuga.mcp.registry import MCPRegistry
from cuga.mcp.runners.subprocess_stdio import SubprocessStdioRunner
from cuga.mcp.telemetry.logging import setup_json_logging
from cuga.mcp.telemetry.metrics import metrics
# Module-level logger obtained from cuga.mcp.telemetry.logging.setup_json_logging()
# (presumably configured for JSON-formatted records, per the helper's name — confirm).
# LifecycleManager uses it to log unexpected MCP failures.
LOGGER = setup_json_logging()
class CircuitState:
    """Consecutive-failure circuit breaker for a single MCP alias.

    After ``threshold`` consecutive failures the circuit opens for
    ``cooldown_s`` seconds, during which :meth:`allow` returns ``False``.
    Once the cooldown has elapsed, calls are allowed again.  The failure
    count is only reset by :meth:`record_success`, so a single further
    failure after the cooldown reopens the circuit immediately.

    Deadlines use ``time.monotonic()`` so wall-clock adjustments (NTP,
    manual changes) cannot lengthen or shorten the cooldown.
    """

    def __init__(self, threshold: int = 3, cooldown_s: float = 10.0) -> None:
        self.threshold = threshold
        self.cooldown_s = cooldown_s
        # Consecutive failures since the last success.
        self.failures = 0
        # Deadline on the time.monotonic() clock; 0.0 means "not open".
        self.open_until = 0.0

    def record_success(self) -> None:
        """Close the circuit and forget previous failures."""
        self.failures = 0
        self.open_until = 0.0

    def record_failure(self) -> None:
        """Count a failure and (re)open the circuit once ``threshold`` is reached."""
        self.failures += 1
        if self.failures >= self.threshold:
            self.open_until = time.monotonic() + self.cooldown_s

    def allow(self) -> bool:
        """Return ``True`` unless the circuit is open and its cooldown has not elapsed."""
        if self.open_until and time.monotonic() < self.open_until:
            return False
        return True
class LifecycleManager:
    """Manage stdio MCP runners and route tool calls through per-alias circuit breakers.

    Runners are cached per ``(alias, transport)`` key and recreated when the
    cached one reports itself unhealthy.  :meth:`call` converts ``Exception``
    failures raised while starting or calling a runner into
    ``ToolResponse(ok=False, ...)`` and feeds the alias's circuit breaker.
    """

    def __init__(self, registry: Optional[MCPRegistry] = None) -> None:
        self.registry = registry or MCPRegistry()
        # Live runners keyed by (alias, transport).
        self.runners: Dict[Tuple[str, str], SubprocessStdioRunner] = {}
        # One circuit breaker per alias, created on first lookup.
        self.circuits: Dict[str, CircuitState] = defaultdict(CircuitState)
        # NOTE(review): not read or written anywhere else in this file; kept so
        # external code that references it keeps working.
        self.pool: Dict[str, deque[SubprocessStdioRunner]] = defaultdict(deque)
        # Per-key locks: without them, two concurrent calls for the same alias
        # could both start a runner and the overwritten one would leak.
        self._start_locks: Dict[Tuple[str, str], asyncio.Lock] = defaultdict(asyncio.Lock)

    def _runner_key(self, spec: ToolSpec) -> Tuple[str, str]:
        """Return the cache key used for *spec*'s runner."""
        return (spec.alias, spec.transport)

    async def ensure_runner(self, spec: ToolSpec) -> SubprocessStdioRunner:
        """Return a healthy runner for *spec*, starting or restarting one if needed.

        Raises:
            ToolUnavailable: if ``spec.transport`` is not ``"stdio"``.
            Whatever ``SubprocessStdioRunner.start()`` raises on failure.
        """
        if spec.transport != "stdio":
            raise ToolUnavailable(f"Unsupported transport: {spec.transport}")
        key = self._runner_key(spec)
        async with self._start_locks[key]:
            existing = self.runners.get(key)
            if existing is not None:
                if existing.is_healthy():
                    return existing
                # Replacing an unhealthy runner: stop it first so its
                # subprocess is not orphaned.
                self.runners.pop(key, None)
                await self._stop_quietly(key, existing)
            runner = SubprocessStdioRunner(
                command=spec.command or "python",
                args=spec.args,
                env=spec.env,
                working_dir=spec.working_dir,
                allowed_commands=self.registry.config.allow_commands,
            )
            await runner.start()
            self.runners[key] = runner
            return runner

    async def _stop_quietly(self, key: Tuple[str, str], runner: SubprocessStdioRunner) -> None:
        """Best-effort ``runner.stop()``: log and swallow ``Exception`` failures."""
        try:
            await runner.stop()
        except Exception:
            LOGGER.exception(
                "Failed to stop MCP runner",
                extra={"alias": key[0], "transport": key[1]},
            )

    async def stop_runner(self, alias: str, transport: str = "stdio") -> None:
        """Stop and forget the runner cached for ``(alias, transport)``, if any."""
        runner = self.runners.pop((alias, transport), None)
        if runner is not None:
            await runner.stop()

    def _failure_response(
        self,
        alias: str,
        spec: ToolSpec,
        circuit: CircuitState,
        exc: Exception,
    ) -> ToolResponse:
        """Record *exc* on *circuit* and in metrics, then build the error response.

        Known MCP errors report ``str(exc)``; anything else is logged with its
        traceback and reported as ``"unexpected error"``.
        """
        circuit.record_failure()
        if isinstance(exc, CallTimeout):
            kind = "timeout"
        elif isinstance(exc, StartupError):
            kind = "startup"
        elif isinstance(exc, ToolUnavailable):
            kind = "unavailable"
        else:
            kind = "unexpected"
        metrics.counter("mcp.errors", {"kind": kind}).inc()
        if kind == "unexpected":
            LOGGER.exception("Unexpected MCP failure", exc_info=exc, extra={"alias": alias})
            return ToolResponse(ok=False, error="unexpected error", metrics={"transport": spec.transport})
        return ToolResponse(ok=False, error=str(exc), metrics={"transport": spec.transport})

    async def call(self, alias: str, request: ToolRequest) -> ToolResponse:
        """Invoke *request* on the tool registered under *alias*.

        Returns an ``ok=False`` response instead of raising when the alias's
        circuit is open, or when starting or calling the runner raises an
        ``Exception``.  Exceptions from ``self.registry.get`` are not
        intercepted.
        """
        spec = self.registry.get(alias)
        metrics.counter("mcp.calls").inc()
        circuit = self.circuits[alias]
        if not circuit.allow():
            return ToolResponse(ok=False, error="circuit open", metrics={"transport": spec.transport})
        try:
            runner = await self.ensure_runner(spec)
        except Exception as exc:  # same contract as call failures below: report, don't raise
            return self._failure_response(alias, spec, circuit, exc)
        stop_timer = metrics.time_block("mcp.latency_ms")
        try:
            payload = {"method": request.method, "params": request.params}
            # A falsy per-request timeout (None or 0) falls back to the spec default.
            raw = await runner.call_with_retry(payload, timeout=request.timeout_s or spec.timeout_s)
            # NOTE(review): only "result" is read; if the server can answer with an
            # "error" member instead, it is dropped here — confirm that
            # call_with_retry raises in that case.
            response = ToolResponse(ok=True, result=raw.get("result"), metrics={"transport": spec.transport})
        except Exception as exc:
            return self._failure_response(alias, spec, circuit, exc)
        finally:
            stop_timer()
        circuit.record_success()
        return response

    async def stop_all(self) -> None:
        """Stop every cached runner and clear the cache.

        Every runner is stopped even if some fail, and the cache is always
        cleared.  Each failure is logged; the first one is re-raised after
        all stops have finished, so callers still observe it.
        """
        entries = list(self.runners.items())
        self.runners.clear()
        outcomes = await asyncio.gather(
            *(runner.stop() for _, runner in entries),
            return_exceptions=True,
        )
        first_error: Optional[BaseException] = None
        for ((alias, transport), _), outcome in zip(entries, outcomes):
            if isinstance(outcome, BaseException):
                LOGGER.error(
                    "Failed to stop MCP runner",
                    exc_info=outcome,
                    extra={"alias": alias, "transport": transport},
                )
                if first_error is None:
                    first_error = outcome
        if first_error is not None:
            raise first_error