Skip to content

Commit ffd73b4

Browse files
authored
Merge pull request #72 from IBM/delete-gateway-fix
Fix gateway deletion and addition
2 parents ef44aee + ddef094 commit ffd73b4

File tree

8 files changed

+202
-118
lines changed

8 files changed

+202
-118
lines changed

mcpgateway/cli.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,13 +60,26 @@ def _needs_app(arg_list: List[str]) -> bool:
6060
is taken as the application path. We therefore look at the first
6161
element of *arg_list* (if any) – if it *starts* with a dash it must be
6262
an option, hence the app path is missing and we should inject ours.
63+
64+
Args:
65+
arg_list (List[str]): List of arguments
66+
67+
Returns:
68+
bool: Returns *True* when the CLI invocation has *no* positional APP path
6369
"""
6470

6571
return len(arg_list) == 0 or arg_list[0].startswith("-")
6672

6773

6874
def _insert_defaults(raw_args: List[str]) -> List[str]:
69-
"""Return a *new* argv with defaults sprinkled in where needed."""
75+
"""Return a *new* argv with defaults sprinkled in where needed.
76+
77+
Args:
78+
raw_args (List[str]): List of input arguments to cli
79+
80+
Returns:
81+
List[str]: List of arguments
82+
"""
7083

7184
args = list(raw_args) # shallow copy – we'll mutate this
7285

mcpgateway/config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,9 @@ def _parse_allowed_origins(cls, v):
122122
# For federation_peers strip out quotes to ensure we're passing valid JSON via env
123123
federation_peers: Annotated[List[str], NoDecode] = []
124124

125+
# Lock file path for initializing gateway service initialize
126+
lock_file_path: str = "/tmp/gateway_init.done"
127+
125128
@field_validator("federation_peers", mode="before")
126129
@classmethod
127130
def _parse_federation_peers(cls, v):

mcpgateway/db.py

Lines changed: 70 additions & 65 deletions
Large diffs are not rendered by default.

mcpgateway/main.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import asyncio
2929
import json
3030
import logging
31+
import os
3132
from contextlib import asynccontextmanager
3233
from typing import Any, AsyncIterator, Dict, List, Optional, Union
3334

@@ -170,7 +171,7 @@ async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
170171
shuts them down in reverse order on exit.
171172
172173
Args:
173-
app (FastAPI): FastAPI app
174+
_app (FastAPI): FastAPI app
174175
175176
Yields:
176177
None
@@ -184,7 +185,15 @@ async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
184185
await tool_service.initialize()
185186
await resource_service.initialize()
186187
await prompt_service.initialize()
187-
await gateway_service.initialize()
188+
try:
189+
# Try to create the file exclusively
190+
fd = os.open(settings.lock_file_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY, 0o644)
191+
except FileExistsError:
192+
logger.info("Gateway already initialized by another worker")
193+
else:
194+
with os.fdopen(fd, "w") as lock_file:
195+
lock_file.write("initialized")
196+
await gateway_service.initialize()
188197
await root_service.initialize()
189198
await completion_service.initialize()
190199
await logging_service.initialize()
@@ -2052,7 +2061,9 @@ async def root_redirect(request: Request):
20522061
RedirectResponse: Redirects to /admin.
20532062
"""
20542063
logger.debug("Redirecting root path to /admin")
2055-
return RedirectResponse(request.url_for("admin_home"))
2064+
root_path = request.scope.get("root_path", "")
2065+
return RedirectResponse(f"{root_path}/admin", status_code=303)
2066+
# return RedirectResponse(request.url_for("admin_home"))
20562067

20572068
else:
20582069
# If UI is disabled, provide API info at root

mcpgateway/schemas.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,7 @@ class ToolRead(BaseModelWithConfig):
402402

403403
id: int
404404
name: str
405-
url: str
405+
url: Optional[str]
406406
description: Optional[str]
407407
request_type: str
408408
integration_type: str

mcpgateway/services/gateway_service.py

Lines changed: 83 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,13 @@
1616

1717
import asyncio
1818
import logging
19-
from datetime import datetime
19+
from datetime import datetime, timezone
2020
from typing import Any, AsyncGenerator, Dict, List, Optional, Set
2121

2222
import httpx
2323
from mcp import ClientSession
2424
from mcp.client.sse import sse_client
2525
from sqlalchemy import select
26-
from sqlalchemy.exc import IntegrityError
2726
from sqlalchemy.orm import Session
2827

2928
from mcpgateway.config import settings
@@ -122,7 +121,6 @@ async def register_gateway(self, db: Session, gateway: GatewayCreate) -> Gateway
122121
123122
Raises:
124123
GatewayNameConflictError: If gateway name already exists
125-
GatewayError: If registration fails
126124
"""
127125
try:
128126
# Check for name conflicts (both active and inactive)
@@ -138,12 +136,43 @@ async def register_gateway(self, db: Session, gateway: GatewayCreate) -> Gateway
138136
auth_type = getattr(gateway, "auth_type", None)
139137
auth_value = getattr(gateway, "auth_value", {})
140138

141-
# Initialize connection and get capabilities
142139
capabilities, tools = await self._initialize_gateway(str(gateway.url), auth_value)
143140

141+
all_names = [td.name for td in tools]
142+
143+
existing_tools = db.execute(select(DbTool).where(DbTool.name.in_(all_names))).scalars().all()
144+
existing_tool_names = [tool.name for tool in existing_tools]
145+
146+
tools = [
147+
DbTool(
148+
name=tool.name,
149+
url=str(gateway.url),
150+
description=tool.description,
151+
integration_type=tool.integration_type,
152+
request_type=tool.request_type,
153+
headers=tool.headers,
154+
input_schema=tool.input_schema,
155+
jsonpath_filter=tool.jsonpath_filter,
156+
auth_type=auth_type,
157+
auth_value=auth_value,
158+
)
159+
for tool in tools
160+
]
161+
162+
existing_tools = [tool for tool in tools if tool.name in existing_tool_names]
163+
new_tools = [tool for tool in tools if tool.name not in existing_tool_names]
164+
144165
# Create DB model
145166
db_gateway = DbGateway(
146-
name=gateway.name, url=str(gateway.url), description=gateway.description, capabilities=capabilities, last_seen=datetime.utcnow(), auth_type=auth_type, auth_value=auth_value
167+
name=gateway.name,
168+
url=str(gateway.url),
169+
description=gateway.description,
170+
capabilities=capabilities,
171+
last_seen=datetime.now(timezone.utc),
172+
auth_type=auth_type,
173+
auth_value=auth_value,
174+
tools=new_tools,
175+
# federated_tools=existing_tools + new_tools
147176
)
148177

149178
# Add to DB
@@ -157,23 +186,20 @@ async def register_gateway(self, db: Session, gateway: GatewayCreate) -> Gateway
157186
# Notify subscribers
158187
await self._notify_gateway_added(db_gateway)
159188

160-
inserted_gateway = db.execute(select(DbGateway).where(DbGateway.name == gateway.name)).scalar_one_or_none()
161-
inserted_gateway_id = inserted_gateway.id
162-
163-
logger.info(f"Registered gateway: {gateway.name}")
164-
165-
for tool in tools:
166-
tool.gateway_id = inserted_gateway_id
167-
await self.tool_service.register_tool(db=db, tool=tool)
168-
169189
return GatewayRead.model_validate(gateway)
170-
171-
except IntegrityError:
172-
db.rollback()
173-
raise GatewayError(f"Gateway already exists: {gateway.name}")
174-
except Exception as e:
175-
db.rollback()
176-
raise GatewayError(f"Failed to register gateway: {str(e)}")
190+
except* ValueError as ve:
191+
logger.error("ValueErrors in group: %s", ve.exceptions)
192+
except* RuntimeError as re:
193+
logger.error("RuntimeErrors in group: %s", re.exceptions)
194+
except* BaseException as other: # catches every other sub-exception
195+
logger.error("Other grouped errors: %s", other.exceptions)
196+
# except IntegrityError as ex:
197+
# logger.error(f"Error adding gateway: {ex}")
198+
# db.rollback()
199+
# raise GatewayError(f"Gateway already exists: {gateway.name}")
200+
# except Exception as e:
201+
# db.rollback()
202+
# raise GatewayError(f"Failed to register gateway: {str(e)}")
177203

178204
async def list_gateways(self, db: Session, include_inactive: bool = False) -> List[GatewayRead]:
179205
"""List all registered gateways.
@@ -393,14 +419,6 @@ async def delete_gateway(self, db: Session, gateway_id: int) -> None:
393419
# Store gateway info for notification before deletion
394420
gateway_info = {"id": gateway.id, "name": gateway.name, "url": gateway.url}
395421

396-
# Remove associated tools
397-
try:
398-
db.query(DbTool).filter(DbTool.gateway_id == gateway_id).delete()
399-
db.commit()
400-
logger.info(f"Deleted tools associated with gateway: {gateway.name}")
401-
except Exception as ex:
402-
logger.warning(f"No tools found: {ex}")
403-
404422
# Hard delete gateway
405423
db.delete(gateway)
406424
db.commit()
@@ -457,28 +475,42 @@ async def forward_request(self, gateway: DbGateway, method: str, params: Optiona
457475
raise GatewayConnectionError(f"Failed to forward request to {gateway.name}: {str(e)}")
458476

459477
async def check_health_of_gateways(self, gateways: List[DbGateway]) -> bool:
460-
"""Health check for gateways
478+
"""Health check for a list of gateways.
479+
480+
Deactivates gateway if gateway is not healthy.
461481
462482
Args:
463-
gateways: Gateways to check
483+
gateways (List[DbGateway]): List of gateways to check if healthy
464484
465485
Returns:
466-
True if gateway is healthy
486+
bool: True if all active gateways are healthy
467487
"""
468-
for gateway in gateways:
469-
if not gateway.is_active:
470-
return False
488+
# Reuse a single HTTP client for all requests
489+
async with httpx.AsyncClient() as client:
490+
for gateway in gateways:
491+
# Inactive gateways are unhealthy
492+
if not gateway.is_active:
493+
continue
471494

472-
try:
473-
# Try to initialize connection
474-
await self._initialize_gateway(gateway.url, gateway.auth_value)
495+
try:
496+
# Ensure auth_value is a dict
497+
auth_data = gateway.auth_value or {}
498+
headers = decode_auth(auth_data)
499+
500+
# Perform the GET and raise on 4xx/5xx
501+
async with client.stream("GET", gateway.url, headers=headers) as response:
502+
# This will raise immediately if status is 4xx/5xx
503+
response.raise_for_status()
475504

476-
# Update last seen
477-
gateway.last_seen = datetime.utcnow()
478-
return True
505+
# Mark successful check
506+
gateway.last_seen = datetime.utcnow()
507+
508+
except Exception:
509+
with SessionLocal() as db:
510+
await self.toggle_gateway_status(db=db, gateway_id=gateway.id, activate=False)
479511

480-
except Exception:
481-
return False
512+
# All gateways passed
513+
return True
482514

483515
async def aggregate_capabilities(self, db: Session) -> Dict[str, Any]:
484516
"""Aggregate capabilities from all gateways.
@@ -579,7 +611,11 @@ async def connect_to_sse_server(server_url: str, authentication: Optional[Dict[s
579611
raise GatewayConnectionError(f"Failed to initialize gateway at {url}: {str(e)}")
580612

581613
def _get_active_gateways(self) -> list[DbGateway]:
582-
"""Sync function for database operations (runs in thread)."""
614+
"""Sync function for database operations (runs in thread).
615+
616+
Returns:
617+
List[DbGateway]: List of active gateways
618+
"""
583619
with SessionLocal() as db:
584620
return db.execute(select(DbGateway).where(DbGateway.is_active)).scalars().all()
585621

@@ -590,9 +626,9 @@ async def _run_health_checks(self) -> None:
590626
# Run sync database code in a thread
591627
gateways = await asyncio.to_thread(self._get_active_gateways)
592628

593-
# Async health checks (non-blocking)
594-
await self.check_health_of_gateways(gateways)
595-
629+
if len(gateways) > 0:
630+
# Async health checks (non-blocking)
631+
await self.check_health_of_gateways(gateways)
596632
except Exception as e:
597633
logger.error(f"Health check run failed: {str(e)}")
598634

mcpgateway/services/tool_service.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -518,12 +518,15 @@ async def invoke_tool(self, db: Session, name: str, arguments: Dict[str, Any]) -
518518
else:
519519
headers = {}
520520

521-
async def connect_to_sse_server(server_url: str):
521+
async def connect_to_sse_server(server_url: str) -> str:
522522
"""
523523
Connect to an MCP server running with SSE transport
524524
525525
Args:
526526
server_url (str): MCP Server SSE URL
527+
528+
Returns:
529+
str: Result of tool call
527530
"""
528531
# Use async with directly to manage the context
529532
async with sse_client(url=server_url, headers=headers) as streams:

tests/hey/payload2.json

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
{
2+
"jsonrpc": "2.0",
3+
"id": 1,
4+
"method": "tools/call",
5+
"params": {
6+
"name": "convert_time",
7+
"arguments": {
8+
"source_timezone": "Europe/Berlin",
9+
"target_timezone": "Europe/Dublin",
10+
"time": "09:00"
11+
}
12+
}
13+
}

0 commit comments

Comments
 (0)