Skip to content

Commit ce5160c

Browse files
authored
Merge pull request #44 from IBM/fix-test-tool
Fix test tool and improve stability
2 parents 51300bb + d40fd54 commit ce5160c

File tree

11 files changed

+130
-103
lines changed

11 files changed

+130
-103
lines changed

mcpgateway/cache/session_registry.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -129,9 +129,6 @@ def __init__(
129129
database_url: Database connection URL (required for database backend)
130130
session_ttl: Session time-to-live in seconds
131131
message_ttl: Message time-to-live in seconds
132-
133-
Raises:
134-
ValueError: If backend is invalid or required URL is missing
135132
"""
136133
super().__init__(backend=backend, redis_url=redis_url, database_url=database_url, session_ttl=session_ttl, message_ttl=message_ttl)
137134
self._sessions: Dict[str, Any] = {} # Local transport cache
@@ -407,13 +404,15 @@ async def respond(
407404
server_id: Optional[str],
408405
user: json,
409406
session_id: str,
407+
base_url: str,
410408
) -> None:
411409
"""Respond to broadcast message is transport relevant to session_id is found locally
412410
413411
Args:
414412
server_id: Server ID
415413
session_id: Session ID
416414
user: User information
415+
base_url: Base URL for the FastAPI request
417416
418417
"""
419418

@@ -425,7 +424,7 @@ async def respond(
425424
transport = self.get_session_sync(session_id)
426425
if transport:
427426
message = json.loads(self._session_message.get("message"))
428-
await self.generate_response(message=message, transport=transport, server_id=server_id, user=user)
427+
await self.generate_response(message=message, transport=transport, server_id=server_id, user=user, base_url=base_url)
429428

430429
elif self._backend == "redis":
431430
await self._pubsub.subscribe(session_id)
@@ -440,7 +439,7 @@ async def respond(
440439
message = json.loads(message)
441440
transport = self.get_session_sync(session_id)
442441
if transport:
443-
await self.generate_response(message=message, transport=transport, server_id=server_id, user=user)
442+
await self.generate_response(message=message, transport=transport, server_id=server_id, user=user, base_url=base_url)
444443
except asyncio.CancelledError:
445444
logger.info(f"PubSub listener for session {session_id} cancelled")
446445
finally:
@@ -494,7 +493,7 @@ async def message_check_loop(session_id):
494493
transport = self.get_session_sync(session_id)
495494
if transport:
496495
logger.info("Ready to respond")
497-
await self.generate_response(message=message, transport=transport, server_id=server_id, user=user)
496+
await self.generate_response(message=message, transport=transport, server_id=server_id, user=user, base_url=base_url)
498497

499498
await asyncio.to_thread(_db_remove, session_id, record.message)
500499

@@ -671,7 +670,7 @@ async def handle_initialize_logic(self, body: dict) -> InitializeResult:
671670
instructions=("MCP Gateway providing federated tools, resources and prompts. Use /admin interface for configuration."),
672671
)
673672

674-
async def generate_response(self, message: json, transport: SSETransport, server_id: Optional[str], user: dict):
673+
async def generate_response(self, message: json, transport: SSETransport, server_id: Optional[str], user: dict, base_url: str):
675674
"""
676675
Generates response according to SSE specifications
677676
@@ -680,6 +679,7 @@ async def generate_response(self, message: json, transport: SSETransport, server
680679
transport: Transport where message should be responded in
681680
server_id: Server ID
682681
user: User information
682+
base_url: Base URL for the FastAPI request
683683
684684
"""
685685
result = {}
@@ -745,9 +745,10 @@ async def generate_response(self, message: json, transport: SSETransport, server
745745
"id": 1,
746746
}
747747
headers = {"Authorization": f"Bearer {user['token']}", "Content-Type": "application/json"}
748+
rpc_url = base_url + "/rpc"
748749
async with httpx.AsyncClient(timeout=settings.federation_timeout, verify=not settings.skip_ssl_verify) as client:
749750
rpc_response = await client.post(
750-
f"http://localhost:{settings.port}/rpc",
751+
url=rpc_url,
751752
json=rpc_input,
752753
headers=headers,
753754
)

mcpgateway/db.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1086,3 +1086,7 @@ def init_db():
10861086
Base.metadata.create_all(bind=engine)
10871087
except SQLAlchemyError as e:
10881088
raise Exception(f"Failed to initialize database: {str(e)}")
1089+
1090+
1091+
if __name__ == "__main__":
1092+
init_db()

mcpgateway/handlers/sampling.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,9 +150,9 @@ async def _add_context(self, _db: Session, messages: List[Dict[str, Any]], _cont
150150
"""Add context to messages.
151151
152152
Args:
153-
db: Database session
153+
_db: Database session
154154
messages: Message list
155-
context_type: Context inclusion type
155+
_context_type: Context inclusion type
156156
157157
Returns:
158158
Messages with added context

mcpgateway/main.py

Lines changed: 61 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -157,12 +157,71 @@
157157
# Initialize cache
158158
resource_cache = ResourceCache(max_size=settings.resource_cache_size, ttl=settings.resource_cache_ttl)
159159

160+
161+
####################
162+
# Startup/Shutdown #
163+
####################
164+
@asynccontextmanager
165+
async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
166+
"""
167+
Manage the application's startup and shutdown lifecycle.
168+
169+
The function initialises every core service on entry and then
170+
shuts them down in reverse order on exit.
171+
172+
Args:
173+
app (FastAPI): FastAPI app
174+
175+
Yields:
176+
None
177+
178+
Raises:
179+
Exception: Any unhandled error that occurs during service
180+
initialisation or shutdown is re-raised to the caller.
181+
"""
182+
logger.info("Starting MCP Gateway services")
183+
try:
184+
await tool_service.initialize()
185+
await resource_service.initialize()
186+
await prompt_service.initialize()
187+
await gateway_service.initialize()
188+
await root_service.initialize()
189+
await completion_service.initialize()
190+
await logging_service.initialize()
191+
await sampling_handler.initialize()
192+
await resource_cache.initialize()
193+
logger.info("All services initialized successfully")
194+
yield
195+
except Exception as e:
196+
logger.error(f"Error during startup: {str(e)}")
197+
raise
198+
finally:
199+
logger.info("Shutting down MCP Gateway services")
200+
for service in [
201+
resource_cache,
202+
sampling_handler,
203+
logging_service,
204+
completion_service,
205+
root_service,
206+
gateway_service,
207+
prompt_service,
208+
resource_service,
209+
tool_service,
210+
]:
211+
try:
212+
await service.shutdown()
213+
except Exception as e:
214+
logger.error(f"Error shutting down {service.__class__.__name__}: {str(e)}")
215+
logger.info("Shutdown complete")
216+
217+
160218
# Initialize FastAPI app
161219
app = FastAPI(
162220
title=settings.app_name,
163221
version=__version__,
164222
description="A FastAPI-based MCP Gateway with federation support",
165223
root_path=settings.app_root_path,
224+
lifespan=lifespan,
166225
)
167226

168227

@@ -620,7 +679,7 @@ async def sse_endpoint(request: Request, server_id: int, user: str = Depends(req
620679
await session_registry.add_session(transport.session_id, transport)
621680
response = await transport.create_sse_response(request)
622681

623-
asyncio.create_task(session_registry.respond(server_id, user, session_id=transport.session_id))
682+
asyncio.create_task(session_registry.respond(server_id, user, session_id=transport.session_id, base_url=base_url))
624683

625684
tasks = BackgroundTasks()
626685
tasks.add_task(session_registry.remove_session, transport.session_id)
@@ -1744,7 +1803,7 @@ async def utility_sse_endpoint(request: Request, user: str = Depends(require_aut
17441803
await transport.connect()
17451804
await session_registry.add_session(transport.session_id, transport)
17461805

1747-
asyncio.create_task(session_registry.respond(None, user, session_id=transport.session_id))
1806+
asyncio.create_task(session_registry.respond(None, user, session_id=transport.session_id, base_url=base_url))
17481807

17491808
response = await transport.create_sse_response(request)
17501809
tasks = BackgroundTasks()
@@ -1909,60 +1968,6 @@ async def healthcheck(db: Session = Depends(get_db)):
19091968
return {"status": "healthy"}
19101969

19111970

1912-
####################
1913-
# Startup/Shutdown #
1914-
####################
1915-
@asynccontextmanager
1916-
async def lifespan() -> AsyncIterator[None]:
1917-
"""
1918-
Manage the application's startup and shutdown lifecycle.
1919-
1920-
The function initialises every core service on entry and then
1921-
shuts them down in reverse order on exit.
1922-
1923-
Yields:
1924-
None
1925-
1926-
Raises:
1927-
Exception: Any unhandled error that occurs during service
1928-
initialisation or shutdown is re-raised to the caller.
1929-
"""
1930-
logger.info("Starting MCP Gateway services")
1931-
try:
1932-
await tool_service.initialize()
1933-
await resource_service.initialize()
1934-
await prompt_service.initialize()
1935-
await gateway_service.initialize()
1936-
await root_service.initialize()
1937-
await completion_service.initialize()
1938-
await logging_service.initialize()
1939-
await sampling_handler.initialize()
1940-
await resource_cache.initialize()
1941-
logger.info("All services initialized successfully")
1942-
yield
1943-
except Exception as e:
1944-
logger.error(f"Error during startup: {str(e)}")
1945-
raise
1946-
finally:
1947-
logger.info("Shutting down MCP Gateway services")
1948-
for service in [
1949-
resource_cache,
1950-
sampling_handler,
1951-
logging_service,
1952-
completion_service,
1953-
root_service,
1954-
gateway_service,
1955-
prompt_service,
1956-
resource_service,
1957-
tool_service,
1958-
]:
1959-
try:
1960-
await service.shutdown()
1961-
except Exception as e:
1962-
logger.error(f"Error shutting down {service.__class__.__name__}: {str(e)}")
1963-
logger.info("Shutdown complete")
1964-
1965-
19661971
# Mount static files
19671972
app.mount("/static", StaticFiles(directory=str(settings.static_dir)), name="static")
19681973

mcpgateway/services/gateway_service.py

Lines changed: 25 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
from mcpgateway.config import settings
3030
from mcpgateway.db import Gateway as DbGateway
31+
from mcpgateway.db import SessionLocal
3132
from mcpgateway.db import Tool as DbTool
3233
from mcpgateway.schemas import GatewayCreate, GatewayRead, GatewayUpdate, ToolCreate
3334
from mcpgateway.services.tool_service import ToolService
@@ -455,28 +456,29 @@ async def forward_request(self, gateway: DbGateway, method: str, params: Optiona
455456
except Exception as e:
456457
raise GatewayConnectionError(f"Failed to forward request to {gateway.name}: {str(e)}")
457458

458-
async def check_gateway_health(self, gateway: DbGateway) -> bool:
459-
"""Check if a gateway is healthy.
459+
async def check_health_of_gateways(self, gateways: List[DbGateway]) -> bool:
460+
"""Health check for gateways
460461
461462
Args:
462-
gateway: Gateway to check
463+
gateways: Gateways to check
463464
464465
Returns:
465466
True if gateway is healthy
466467
"""
467-
if not gateway.is_active:
468-
return False
468+
for gateway in gateways:
469+
if not gateway.is_active:
470+
return False
469471

470-
try:
471-
# Try to initialize connection
472-
await self._initialize_gateway(gateway.url, gateway.auth_value)
472+
try:
473+
# Try to initialize connection
474+
await self._initialize_gateway(gateway.url, gateway.auth_value)
473475

474-
# Update last seen
475-
gateway.last_seen = datetime.utcnow()
476-
return True
476+
# Update last seen
477+
gateway.last_seen = datetime.utcnow()
478+
return True
477479

478-
except Exception:
479-
return False
480+
except Exception:
481+
return False
480482

481483
async def aggregate_capabilities(self, db: Session) -> Dict[str, Any]:
482484
"""Aggregate capabilities from all gateways.
@@ -576,29 +578,24 @@ async def connect_to_sse_server(server_url: str, authentication: Optional[Dict[s
576578
except Exception as e:
577579
raise GatewayConnectionError(f"Failed to initialize gateway at {url}: {str(e)}")
578580

581+
def _get_active_gateways(self) -> list[DbGateway]:
582+
"""Sync function for database operations (runs in thread)."""
583+
with SessionLocal() as db:
584+
return db.execute(select(DbGateway).where(DbGateway.is_active)).scalars().all()
585+
579586
async def _run_health_checks(self) -> None:
580-
"""Run periodic health checks on all gateways."""
587+
"""Run health checks with sync Session in async code."""
581588
while True:
582589
try:
583-
async with Session() as db:
584-
# Get active gateways
585-
gateways = db.execute(select(DbGateway).where(DbGateway.is_active)).scalars().all()
586-
587-
# Check each gateway
588-
for gateway in gateways:
589-
try:
590-
is_healthy = await self.check_gateway_health(gateway)
591-
if not is_healthy:
592-
logger.warning(f"Gateway {gateway.name} is unhealthy")
593-
except Exception as e:
594-
logger.error(f"Health check failed for {gateway.name}: {str(e)}")
590+
# Run sync database code in a thread
591+
gateways = await asyncio.to_thread(self._get_active_gateways)
595592

596-
db.commit()
593+
# Async health checks (non-blocking)
594+
await self.check_health_of_gateways(gateways)
597595

598596
except Exception as e:
599597
logger.error(f"Health check run failed: {str(e)}")
600598

601-
# Wait for next check
602599
await asyncio.sleep(self._health_check_interval)
603600

604601
def _get_auth_headers(self) -> Dict[str, str]:

mcpgateway/services/resource_service.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -222,8 +222,6 @@ async def list_resources(self, db: Session, include_inactive: bool = False) -> L
222222
db (Session): The SQLAlchemy database session.
223223
include_inactive (bool): If True, include inactive resources in the result.
224224
Defaults to False.
225-
cursor (Optional[str], optional): An opaque cursor token for pagination. Currently,
226-
this parameter is ignored. Defaults to None.
227225
228226
Returns:
229227
List[ResourceRead]: A list of resources represented as ResourceRead objects.
@@ -249,8 +247,6 @@ async def list_server_resources(self, db: Session, server_id: int, include_inact
249247
server_id (int): Server ID
250248
include_inactive (bool): If True, include inactive resources in the result.
251249
Defaults to False.
252-
cursor (Optional[str], optional): An opaque cursor token for pagination. Currently,
253-
this parameter is ignored. Defaults to None.
254250
255251
Returns:
256252
List[ResourceRead]: A list of resources represented as ResourceRead objects.

mcpgateway/services/tool_service.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -521,6 +521,9 @@ async def invoke_tool(self, db: Session, name: str, arguments: Dict[str, Any]) -
521521
async def connect_to_sse_server(server_url: str):
522522
"""
523523
Connect to an MCP server running with SSE transport
524+
525+
Args:
526+
server_url (str): MCP Server SSE URL
524527
"""
525528
# Use async with directly to manage the context
526529
async with sse_client(url=server_url, headers=headers) as streams:

0 commit comments

Comments
 (0)