16
16
17
17
import asyncio
18
18
import logging
19
- from datetime import datetime
19
+ from datetime import datetime , timezone
20
20
from typing import Any , AsyncGenerator , Dict , List , Optional , Set
21
21
22
22
import httpx
23
23
from mcp import ClientSession
24
24
from mcp .client .sse import sse_client
25
25
from sqlalchemy import select
26
- from sqlalchemy .exc import IntegrityError
27
26
from sqlalchemy .orm import Session
28
27
29
28
from mcpgateway .config import settings
@@ -122,7 +121,6 @@ async def register_gateway(self, db: Session, gateway: GatewayCreate) -> Gateway
122
121
123
122
Raises:
124
123
GatewayNameConflictError: If gateway name already exists
125
- GatewayError: If registration fails
126
124
"""
127
125
try :
128
126
# Check for name conflicts (both active and inactive)
@@ -138,12 +136,43 @@ async def register_gateway(self, db: Session, gateway: GatewayCreate) -> Gateway
138
136
auth_type = getattr (gateway , "auth_type" , None )
139
137
auth_value = getattr (gateway , "auth_value" , {})
140
138
141
- # Initialize connection and get capabilities
142
139
capabilities , tools = await self ._initialize_gateway (str (gateway .url ), auth_value )
143
140
141
+ all_names = [td .name for td in tools ]
142
+
143
+ existing_tools = db .execute (select (DbTool ).where (DbTool .name .in_ (all_names ))).scalars ().all ()
144
+ existing_tool_names = [tool .name for tool in existing_tools ]
145
+
146
+ tools = [
147
+ DbTool (
148
+ name = tool .name ,
149
+ url = str (gateway .url ),
150
+ description = tool .description ,
151
+ integration_type = tool .integration_type ,
152
+ request_type = tool .request_type ,
153
+ headers = tool .headers ,
154
+ input_schema = tool .input_schema ,
155
+ jsonpath_filter = tool .jsonpath_filter ,
156
+ auth_type = auth_type ,
157
+ auth_value = auth_value ,
158
+ )
159
+ for tool in tools
160
+ ]
161
+
162
+ existing_tools = [tool for tool in tools if tool .name in existing_tool_names ]
163
+ new_tools = [tool for tool in tools if tool .name not in existing_tool_names ]
164
+
144
165
# Create DB model
145
166
db_gateway = DbGateway (
146
- name = gateway .name , url = str (gateway .url ), description = gateway .description , capabilities = capabilities , last_seen = datetime .utcnow (), auth_type = auth_type , auth_value = auth_value
167
+ name = gateway .name ,
168
+ url = str (gateway .url ),
169
+ description = gateway .description ,
170
+ capabilities = capabilities ,
171
+ last_seen = datetime .now (timezone .utc ),
172
+ auth_type = auth_type ,
173
+ auth_value = auth_value ,
174
+ tools = new_tools ,
175
+ # federated_tools=existing_tools + new_tools
147
176
)
148
177
149
178
# Add to DB
@@ -157,23 +186,20 @@ async def register_gateway(self, db: Session, gateway: GatewayCreate) -> Gateway
157
186
# Notify subscribers
158
187
await self ._notify_gateway_added (db_gateway )
159
188
160
- inserted_gateway = db .execute (select (DbGateway ).where (DbGateway .name == gateway .name )).scalar_one_or_none ()
161
- inserted_gateway_id = inserted_gateway .id
162
-
163
- logger .info (f"Registered gateway: { gateway .name } " )
164
-
165
- for tool in tools :
166
- tool .gateway_id = inserted_gateway_id
167
- await self .tool_service .register_tool (db = db , tool = tool )
168
-
169
189
return GatewayRead .model_validate (gateway )
170
-
171
- except IntegrityError :
172
- db .rollback ()
173
- raise GatewayError (f"Gateway already exists: { gateway .name } " )
174
- except Exception as e :
175
- db .rollback ()
176
- raise GatewayError (f"Failed to register gateway: { str (e )} " )
190
+ except* ValueError as ve :
191
+ logger .error ("ValueErrors in group: %s" , ve .exceptions )
192
+ except* RuntimeError as re :
193
+ logger .error ("RuntimeErrors in group: %s" , re .exceptions )
194
+ except* BaseException as other : # catches every other sub-exception
195
+ logger .error ("Other grouped errors: %s" , other .exceptions )
196
+ # except IntegrityError as ex:
197
+ # logger.error(f"Error adding gateway: {ex}")
198
+ # db.rollback()
199
+ # raise GatewayError(f"Gateway already exists: {gateway.name}")
200
+ # except Exception as e:
201
+ # db.rollback()
202
+ # raise GatewayError(f"Failed to register gateway: {str(e)}")
177
203
178
204
async def list_gateways (self , db : Session , include_inactive : bool = False ) -> List [GatewayRead ]:
179
205
"""List all registered gateways.
@@ -393,14 +419,6 @@ async def delete_gateway(self, db: Session, gateway_id: int) -> None:
393
419
# Store gateway info for notification before deletion
394
420
gateway_info = {"id" : gateway .id , "name" : gateway .name , "url" : gateway .url }
395
421
396
- # Remove associated tools
397
- try :
398
- db .query (DbTool ).filter (DbTool .gateway_id == gateway_id ).delete ()
399
- db .commit ()
400
- logger .info (f"Deleted tools associated with gateway: { gateway .name } " )
401
- except Exception as ex :
402
- logger .warning (f"No tools found: { ex } " )
403
-
404
422
# Hard delete gateway
405
423
db .delete (gateway )
406
424
db .commit ()
@@ -457,28 +475,42 @@ async def forward_request(self, gateway: DbGateway, method: str, params: Optiona
457
475
raise GatewayConnectionError (f"Failed to forward request to { gateway .name } : { str (e )} " )
458
476
459
477
async def check_health_of_gateways (self , gateways : List [DbGateway ]) -> bool :
460
- """Health check for gateways
478
+ """Health check for a list of gateways.
479
+
480
+ Deactivates gateway if gateway is not healthy.
461
481
462
482
Args:
463
- gateways: Gateways to check
483
+ gateways (List[DbGateway]): List of gateways to check if healthy
464
484
465
485
Returns:
466
- True if gateway is healthy
486
+ bool: True if all active gateways are healthy
467
487
"""
468
- for gateway in gateways :
469
- if not gateway .is_active :
470
- return False
488
+ # Reuse a single HTTP client for all requests
489
+ async with httpx .AsyncClient () as client :
490
+ for gateway in gateways :
491
+ # Inactive gateways are unhealthy
492
+ if not gateway .is_active :
493
+ continue
471
494
472
- try :
473
- # Try to initialize connection
474
- await self ._initialize_gateway (gateway .url , gateway .auth_value )
495
+ try :
496
+ # Ensure auth_value is a dict
497
+ auth_data = gateway .auth_value or {}
498
+ headers = decode_auth (auth_data )
499
+
500
+ # Perform the GET and raise on 4xx/5xx
501
+ async with client .stream ("GET" , gateway .url , headers = headers ) as response :
502
+ # This will raise immediately if status is 4xx/5xx
503
+ response .raise_for_status ()
475
504
476
- # Update last seen
477
- gateway .last_seen = datetime .utcnow ()
478
- return True
505
+ # Mark successful check
506
+ gateway .last_seen = datetime .utcnow ()
507
+
508
+ except Exception :
509
+ with SessionLocal () as db :
510
+ await self .toggle_gateway_status (db = db , gateway_id = gateway .id , activate = False )
479
511
480
- except Exception :
481
- return False
512
+ # All gateways passed
513
+ return True
482
514
483
515
async def aggregate_capabilities (self , db : Session ) -> Dict [str , Any ]:
484
516
"""Aggregate capabilities from all gateways.
@@ -579,7 +611,11 @@ async def connect_to_sse_server(server_url: str, authentication: Optional[Dict[s
579
611
raise GatewayConnectionError (f"Failed to initialize gateway at { url } : { str (e )} " )
580
612
581
613
def _get_active_gateways (self ) -> list [DbGateway ]:
582
- """Sync function for database operations (runs in thread)."""
614
+ """Sync function for database operations (runs in thread).
615
+
616
+ Returns:
617
+ List[DbGateway]: List of active gateways
618
+ """
583
619
with SessionLocal () as db :
584
620
return db .execute (select (DbGateway ).where (DbGateway .is_active )).scalars ().all ()
585
621
@@ -590,9 +626,9 @@ async def _run_health_checks(self) -> None:
590
626
# Run sync database code in a thread
591
627
gateways = await asyncio .to_thread (self ._get_active_gateways )
592
628
593
- # Async health checks (non-blocking)
594
- await self . check_health_of_gateways ( gateways )
595
-
629
+ if len ( gateways ) > 0 :
630
+ # Async health checks (non-blocking )
631
+ await self . check_health_of_gateways ( gateways )
596
632
except Exception as e :
597
633
logger .error (f"Health check run failed: { str (e )} " )
598
634
0 commit comments