Skip to content

Commit ddef094

Browse files
committed
Bug fixes
Signed-off-by: Madhav Kandukuri <[email protected]>
1 parent a9c7b03 commit ddef094

File tree

6 files changed

+106
-33
lines changed

6 files changed

+106
-33
lines changed

mcpgateway/cli.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,13 +60,26 @@ def _needs_app(arg_list: List[str]) -> bool:
6060
is taken as the application path. We therefore look at the first
6161
element of *arg_list* (if any) – if it *starts* with a dash it must be
6262
an option, hence the app path is missing and we should inject ours.
63+
64+
Args:
65+
arg_list (List[str]): List of arguments
66+
67+
Returns:
68+
bool: Returns *True* when the CLI invocation has *no* positional APP path
6369
"""
6470

6571
return len(arg_list) == 0 or arg_list[0].startswith("-")
6672

6773

6874
def _insert_defaults(raw_args: List[str]) -> List[str]:
69-
"""Return a *new* argv with defaults sprinkled in where needed."""
75+
"""Return a *new* argv with defaults sprinkled in where needed.
76+
77+
Args:
78+
raw_args (List[str]): List of input arguments to cli
79+
80+
Returns:
81+
List[str]: List of arguments
82+
"""
7083

7184
args = list(raw_args) # shallow copy – we'll mutate this
7285

mcpgateway/config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,9 @@ def _parse_allowed_origins(cls, v):
113113
# For federation_peers strip out quotes to ensure we're passing valid JSON via env
114114
federation_peers: Annotated[List[str], NoDecode] = []
115115

116+
# Lock file path for initializing gateway service initialize
117+
lock_file_path: str = "/tmp/gateway_init.done"
118+
116119
@field_validator("federation_peers", mode="before")
117120
@classmethod
118121
def _parse_federation_peers(cls, v):

mcpgateway/main.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import asyncio
2929
import json
3030
import logging
31+
import os
3132
from contextlib import asynccontextmanager
3233
from typing import Any, AsyncIterator, Dict, List, Optional, Union
3334

@@ -170,7 +171,7 @@ async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
170171
shuts them down in reverse order on exit.
171172
172173
Args:
173-
app (FastAPI): FastAPI app
174+
_app (FastAPI): FastAPI app
174175
175176
Yields:
176177
None
@@ -184,7 +185,15 @@ async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
184185
await tool_service.initialize()
185186
await resource_service.initialize()
186187
await prompt_service.initialize()
187-
await gateway_service.initialize()
188+
try:
189+
# Try to create the file exclusively
190+
fd = os.open(settings.lock_file_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY, 0o644)
191+
except FileExistsError:
192+
logger.info("Gateway already initialized by another worker")
193+
else:
194+
with os.fdopen(fd, "w") as lock_file:
195+
lock_file.write("initialized")
196+
await gateway_service.initialize()
188197
await root_service.initialize()
189198
await completion_service.initialize()
190199
await logging_service.initialize()
@@ -2019,7 +2028,9 @@ async def root_redirect(request: Request):
20192028
RedirectResponse: Redirects to /admin.
20202029
"""
20212030
logger.debug("Redirecting root path to /admin")
2022-
return RedirectResponse(request.url_for("admin_home"))
2031+
root_path = request.scope.get("root_path", "")
2032+
return RedirectResponse(f"{root_path}/admin", status_code=303)
2033+
# return RedirectResponse(request.url_for("admin_home"))
20232034

20242035
else:
20252036
# If UI is disabled, provide API info at root

mcpgateway/services/gateway_service.py

Lines changed: 58 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
from mcp import ClientSession
2424
from mcp.client.sse import sse_client
2525
from sqlalchemy import select
26-
from sqlalchemy.exc import IntegrityError
2726
from sqlalchemy.orm import Session
2827

2928
from mcpgateway.config import settings
@@ -122,7 +121,6 @@ async def register_gateway(self, db: Session, gateway: GatewayCreate) -> Gateway
122121
123122
Raises:
124123
GatewayNameConflictError: If gateway name already exists
125-
GatewayError: If registration fails
126124
"""
127125
try:
128126
# Check for name conflicts (both active and inactive)
@@ -138,13 +136,17 @@ async def register_gateway(self, db: Session, gateway: GatewayCreate) -> Gateway
138136
auth_type = getattr(gateway, "auth_type", None)
139137
auth_value = getattr(gateway, "auth_value", {})
140138

141-
# Initialize connection and get capabilities
142139
capabilities, tools = await self._initialize_gateway(str(gateway.url), auth_value)
143-
140+
141+
all_names = [td.name for td in tools]
142+
143+
existing_tools = db.execute(select(DbTool).where(DbTool.name.in_(all_names))).scalars().all()
144+
existing_tool_names = [tool.name for tool in existing_tools]
145+
144146
tools = [
145147
DbTool(
146148
name=tool.name,
147-
url=tool.url,
149+
url=str(gateway.url),
148150
description=tool.description,
149151
integration_type=tool.integration_type,
150152
request_type=tool.request_type,
@@ -157,6 +159,9 @@ async def register_gateway(self, db: Session, gateway: GatewayCreate) -> Gateway
157159
for tool in tools
158160
]
159161

162+
existing_tools = [tool for tool in tools if tool.name in existing_tool_names]
163+
new_tools = [tool for tool in tools if tool.name not in existing_tool_names]
164+
160165
# Create DB model
161166
db_gateway = DbGateway(
162167
name=gateway.name,
@@ -166,7 +171,8 @@ async def register_gateway(self, db: Session, gateway: GatewayCreate) -> Gateway
166171
last_seen=datetime.now(timezone.utc),
167172
auth_type=auth_type,
168173
auth_value=auth_value,
169-
tools=tools,
174+
tools=new_tools,
175+
# federated_tools=existing_tools + new_tools
170176
)
171177

172178
# Add to DB
@@ -181,12 +187,19 @@ async def register_gateway(self, db: Session, gateway: GatewayCreate) -> Gateway
181187
await self._notify_gateway_added(db_gateway)
182188

183189
return GatewayRead.model_validate(gateway)
184-
except IntegrityError:
185-
db.rollback()
186-
raise GatewayError(f"Gateway already exists: {gateway.name}")
187-
except Exception as e:
188-
db.rollback()
189-
raise GatewayError(f"Failed to register gateway: {str(e)}")
190+
except* ValueError as ve:
191+
logger.error("ValueErrors in group: %s", ve.exceptions)
192+
except* RuntimeError as re:
193+
logger.error("RuntimeErrors in group: %s", re.exceptions)
194+
except* BaseException as other: # catches every other sub-exception
195+
logger.error("Other grouped errors: %s", other.exceptions)
196+
# except IntegrityError as ex:
197+
# logger.error(f"Error adding gateway: {ex}")
198+
# db.rollback()
199+
# raise GatewayError(f"Gateway already exists: {gateway.name}")
200+
# except Exception as e:
201+
# db.rollback()
202+
# raise GatewayError(f"Failed to register gateway: {str(e)}")
190203

191204
async def list_gateways(self, db: Session, include_inactive: bool = False) -> List[GatewayRead]:
192205
"""List all registered gateways.
@@ -462,28 +475,42 @@ async def forward_request(self, gateway: DbGateway, method: str, params: Optiona
462475
raise GatewayConnectionError(f"Failed to forward request to {gateway.name}: {str(e)}")
463476

464477
async def check_health_of_gateways(self, gateways: List[DbGateway]) -> bool:
465-
"""Health check for gateways
478+
"""Health check for a list of gateways.
479+
480+
Deactivates gateway if gateway is not healthy.
466481
467482
Args:
468-
gateways: Gateways to check
483+
gateways (List[DbGateway]): List of gateways to check if healthy
469484
470485
Returns:
471-
True if gateway is healthy
486+
bool: True if all active gateways are healthy
472487
"""
473-
for gateway in gateways:
474-
if not gateway.is_active:
475-
return False
488+
# Reuse a single HTTP client for all requests
489+
async with httpx.AsyncClient() as client:
490+
for gateway in gateways:
491+
# Inactive gateways are unhealthy
492+
if not gateway.is_active:
493+
continue
476494

477-
try:
478-
# Try to initialize connection
479-
await self._initialize_gateway(gateway.url, gateway.auth_value)
495+
try:
496+
# Ensure auth_value is a dict
497+
auth_data = gateway.auth_value or {}
498+
headers = decode_auth(auth_data)
499+
500+
# Perform the GET and raise on 4xx/5xx
501+
async with client.stream("GET", gateway.url, headers=headers) as response:
502+
# This will raise immediately if status is 4xx/5xx
503+
response.raise_for_status()
504+
505+
# Mark successful check
506+
gateway.last_seen = datetime.utcnow()
480507

481-
# Update last seen
482-
gateway.last_seen = datetime.utcnow()
483-
return True
508+
except Exception:
509+
with SessionLocal() as db:
510+
await self.toggle_gateway_status(db=db, gateway_id=gateway.id, activate=False)
484511

485-
except Exception:
486-
return False
512+
# All gateways passed
513+
return True
487514

488515
async def aggregate_capabilities(self, db: Session) -> Dict[str, Any]:
489516
"""Aggregate capabilities from all gateways.
@@ -584,7 +611,11 @@ async def connect_to_sse_server(server_url: str, authentication: Optional[Dict[s
584611
raise GatewayConnectionError(f"Failed to initialize gateway at {url}: {str(e)}")
585612

586613
def _get_active_gateways(self) -> list[DbGateway]:
587-
"""Sync function for database operations (runs in thread)."""
614+
"""Sync function for database operations (runs in thread).
615+
616+
Returns:
617+
List[DbGateway]: List of active gateways
618+
"""
588619
with SessionLocal() as db:
589620
return db.execute(select(DbGateway).where(DbGateway.is_active)).scalars().all()
590621

@@ -598,7 +629,6 @@ async def _run_health_checks(self) -> None:
598629
if len(gateways) > 0:
599630
# Async health checks (non-blocking)
600631
await self.check_health_of_gateways(gateways)
601-
602632
except Exception as e:
603633
logger.error(f"Health check run failed: {str(e)}")
604634

mcpgateway/services/tool_service.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -518,12 +518,15 @@ async def invoke_tool(self, db: Session, name: str, arguments: Dict[str, Any]) -
518518
else:
519519
headers = {}
520520

521-
async def connect_to_sse_server(server_url: str):
521+
async def connect_to_sse_server(server_url: str) -> str:
522522
"""
523523
Connect to an MCP server running with SSE transport
524524
525525
Args:
526526
server_url (str): MCP Server SSE URL
527+
528+
Returns:
529+
str: Result of tool call
527530
"""
528531
# Use async with directly to manage the context
529532
async with sse_client(url=server_url, headers=headers) as streams:

tests/hey/payload2.json

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
{
2+
"jsonrpc": "2.0",
3+
"id": 1,
4+
"method": "tools/call",
5+
"params": {
6+
"name": "convert_time",
7+
"arguments": {
8+
"source_timezone": "Europe/Berlin",
9+
"target_timezone": "Europe/Dublin",
10+
"time": "09:00"
11+
}
12+
}
13+
}

0 commit comments

Comments
 (0)