23
23
from mcp import ClientSession
24
24
from mcp .client .sse import sse_client
25
25
from sqlalchemy import select
26
- from sqlalchemy .exc import IntegrityError
27
26
from sqlalchemy .orm import Session
28
27
29
28
from mcpgateway .config import settings
@@ -122,7 +121,6 @@ async def register_gateway(self, db: Session, gateway: GatewayCreate) -> Gateway
122
121
123
122
Raises:
124
123
GatewayNameConflictError: If gateway name already exists
125
- GatewayError: If registration fails
126
124
"""
127
125
try :
128
126
# Check for name conflicts (both active and inactive)
@@ -138,13 +136,17 @@ async def register_gateway(self, db: Session, gateway: GatewayCreate) -> Gateway
138
136
auth_type = getattr (gateway , "auth_type" , None )
139
137
auth_value = getattr (gateway , "auth_value" , {})
140
138
141
- # Initialize connection and get capabilities
142
139
capabilities , tools = await self ._initialize_gateway (str (gateway .url ), auth_value )
143
-
140
+
141
+ all_names = [td .name for td in tools ]
142
+
143
+ existing_tools = db .execute (select (DbTool ).where (DbTool .name .in_ (all_names ))).scalars ().all ()
144
+ existing_tool_names = [tool .name for tool in existing_tools ]
145
+
144
146
tools = [
145
147
DbTool (
146
148
name = tool .name ,
147
- url = tool .url ,
149
+ url = str ( gateway .url ) ,
148
150
description = tool .description ,
149
151
integration_type = tool .integration_type ,
150
152
request_type = tool .request_type ,
@@ -157,6 +159,9 @@ async def register_gateway(self, db: Session, gateway: GatewayCreate) -> Gateway
157
159
for tool in tools
158
160
]
159
161
162
+ existing_tools = [tool for tool in tools if tool .name in existing_tool_names ]
163
+ new_tools = [tool for tool in tools if tool .name not in existing_tool_names ]
164
+
160
165
# Create DB model
161
166
db_gateway = DbGateway (
162
167
name = gateway .name ,
@@ -166,7 +171,8 @@ async def register_gateway(self, db: Session, gateway: GatewayCreate) -> Gateway
166
171
last_seen = datetime .now (timezone .utc ),
167
172
auth_type = auth_type ,
168
173
auth_value = auth_value ,
169
- tools = tools ,
174
+ tools = new_tools ,
175
+ # federated_tools=existing_tools + new_tools
170
176
)
171
177
172
178
# Add to DB
@@ -181,12 +187,19 @@ async def register_gateway(self, db: Session, gateway: GatewayCreate) -> Gateway
181
187
await self ._notify_gateway_added (db_gateway )
182
188
183
189
return GatewayRead .model_validate (gateway )
184
- except IntegrityError :
185
- db .rollback ()
186
- raise GatewayError (f"Gateway already exists: { gateway .name } " )
187
- except Exception as e :
188
- db .rollback ()
189
- raise GatewayError (f"Failed to register gateway: { str (e )} " )
190
+ except* ValueError as ve :
191
+ logger .error ("ValueErrors in group: %s" , ve .exceptions )
192
+ except* RuntimeError as re :
193
+ logger .error ("RuntimeErrors in group: %s" , re .exceptions )
194
+ except* BaseException as other : # catches every other sub-exception
195
+ logger .error ("Other grouped errors: %s" , other .exceptions )
196
+ # except IntegrityError as ex:
197
+ # logger.error(f"Error adding gateway: {ex}")
198
+ # db.rollback()
199
+ # raise GatewayError(f"Gateway already exists: {gateway.name}")
200
+ # except Exception as e:
201
+ # db.rollback()
202
+ # raise GatewayError(f"Failed to register gateway: {str(e)}")
190
203
191
204
async def list_gateways (self , db : Session , include_inactive : bool = False ) -> List [GatewayRead ]:
192
205
"""List all registered gateways.
@@ -462,28 +475,42 @@ async def forward_request(self, gateway: DbGateway, method: str, params: Optiona
462
475
raise GatewayConnectionError (f"Failed to forward request to { gateway .name } : { str (e )} " )
463
476
464
477
async def check_health_of_gateways (self , gateways : List [DbGateway ]) -> bool :
465
- """Health check for gateways
478
+ """Health check for a list of gateways.
479
+
480
+ Deactivates gateway if gateway is not healthy.
466
481
467
482
Args:
468
- gateways: Gateways to check
483
+ gateways (List[DbGateway]): List of gateways to check if healthy
469
484
470
485
Returns:
471
- True if gateway is healthy
486
+ bool: True if all active gateways are healthy
472
487
"""
473
- for gateway in gateways :
474
- if not gateway .is_active :
475
- return False
488
+ # Reuse a single HTTP client for all requests
489
+ async with httpx .AsyncClient () as client :
490
+ for gateway in gateways :
491
+ # Inactive gateways are unhealthy
492
+ if not gateway .is_active :
493
+ continue
476
494
477
- try :
478
- # Try to initialize connection
479
- await self ._initialize_gateway (gateway .url , gateway .auth_value )
495
+ try :
496
+ # Ensure auth_value is a dict
497
+ auth_data = gateway .auth_value or {}
498
+ headers = decode_auth (auth_data )
499
+
500
+ # Perform the GET and raise on 4xx/5xx
501
+ async with client .stream ("GET" , gateway .url , headers = headers ) as response :
502
+ # This will raise immediately if status is 4xx/5xx
503
+ response .raise_for_status ()
504
+
505
+ # Mark successful check
506
+ gateway .last_seen = datetime .utcnow ()
480
507
481
- # Update last seen
482
- gateway . last_seen = datetime . utcnow ()
483
- return True
508
+ except Exception :
509
+ with SessionLocal () as db :
510
+ await self . toggle_gateway_status ( db = db , gateway_id = gateway . id , activate = False )
484
511
485
- except Exception :
486
- return False
512
+ # All gateways passed
513
+ return True
487
514
488
515
async def aggregate_capabilities (self , db : Session ) -> Dict [str , Any ]:
489
516
"""Aggregate capabilities from all gateways.
@@ -584,7 +611,11 @@ async def connect_to_sse_server(server_url: str, authentication: Optional[Dict[s
584
611
raise GatewayConnectionError (f"Failed to initialize gateway at { url } : { str (e )} " )
585
612
586
613
def _get_active_gateways (self ) -> list [DbGateway ]:
587
- """Sync function for database operations (runs in thread)."""
614
+ """Sync function for database operations (runs in thread).
615
+
616
+ Returns:
617
+ List[DbGateway]: List of active gateways
618
+ """
588
619
with SessionLocal () as db :
589
620
return db .execute (select (DbGateway ).where (DbGateway .is_active )).scalars ().all ()
590
621
@@ -598,7 +629,6 @@ async def _run_health_checks(self) -> None:
598
629
if len (gateways ) > 0 :
599
630
# Async health checks (non-blocking)
600
631
await self .check_health_of_gateways (gateways )
601
-
602
632
except Exception as e :
603
633
logger .error (f"Health check run failed: { str (e )} " )
604
634
0 commit comments