|
16 | 16 |
|
17 | 17 | import asyncio
|
18 | 18 | import logging
|
| 19 | +import uuid |
19 | 20 | from datetime import datetime, timezone
|
20 | 21 | from typing import Any, AsyncGenerator, Dict, List, Optional, Set
|
21 | 22 |
|
22 | 23 | import httpx
|
| 24 | +from filelock import FileLock, Timeout |
23 | 25 | from mcp import ClientSession
|
24 | 26 | from mcp.client.sse import sse_client
|
25 | 27 | from sqlalchemy import select
|
|
33 | 35 | from mcpgateway.services.tool_service import ToolService
|
34 | 36 | from mcpgateway.utils.services_auth import decode_auth
|
35 | 37 |
|
| 38 | +try: |
| 39 | + import redis |
| 40 | + REDIS_AVAILABLE = True |
| 41 | +except ImportError: |
| 42 | + REDIS_AVAILABLE = False |
| 43 | + logging.info("redis package not installed; Redis-based leader election is unavailable.") |
| 44 | + |
| 45 | +# logging.getLogger("httpx").setLevel(logging.WARNING) # Disables httpx logs for regular health checks |
36 | 46 | logger = logging.getLogger(__name__)
|
37 | 47 |
|
38 | 48 |
|
| 49 | +GW_FAILURE_THRESHOLD = settings.unhealthy_threshold |
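| 50 | +# A value of -1 disables automatic deactivation after failed health checks (see _handle_gateway_failure). |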
| 50 | +GW_HEALTH_CHECK_INTERVAL = settings.health_check_interval |
| 51 | + |
| 52 | + |
39 | 53 | class GatewayError(Exception):
|
40 | 54 | """Base class for gateway-related errors."""
|
41 | 55 |
|
@@ -83,17 +97,45 @@ def __init__(self):
|
83 | 97 | """Initialize the gateway service."""
|
84 | 98 | self._event_subscribers: List[asyncio.Queue] = []
|
85 | 99 | self._http_client = httpx.AsyncClient(timeout=settings.federation_timeout, verify=not settings.skip_ssl_verify)
|
86 |
| - self._health_check_interval = 60 # seconds |
| 100 | + self._health_check_interval = GW_HEALTH_CHECK_INTERVAL |
87 | 101 | self._health_check_task: Optional[asyncio.Task] = None
|
88 | 102 | self._active_gateways: Set[str] = set() # Track active gateway URLs
|
89 | 103 | self._stream_response = None
|
90 | 104 | self._pending_responses = {}
|
91 | 105 | self.tool_service = ToolService()
|
| 106 | + self._gateway_failure_counts: dict[str, int] = {} |
| 107 | + |
| 108 | + # For health checks, elect a single leader instance so only one worker runs them. |
| 109 | + self.redis_url = settings.redis_url if settings.cache_type == "redis" else None |
| 110 | + |
| 111 | + if self.redis_url and REDIS_AVAILABLE: |
| 112 | + self._redis_client = redis.from_url(self.redis_url) |
| 113 | + self._instance_id = str(uuid.uuid4()) # Unique ID for this process |
| 114 | + self._leader_key = "gateway_service_leader" |
| 115 | + self._leader_ttl = 40 # seconds |
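| 116 | + # NOTE: the leader key is refreshed once per health-check cycle, so the TTL should exceed GW_HEALTH_CHECK_INTERVAL. |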
| 116 | +else: |
| 117 | + # Fallback: no Redis. Use a file-based lock in multi-worker mode; cache_type == "none" (single worker) needs no lock. |
| 118 | + self._redis_client = None |
| 119 | + self._lock_path = settings.filelock_path if settings.cache_type != "none" else None |
| 120 | + self._file_lock = FileLock(self._lock_path) if self._lock_path else None |
92 | 121 |
|
93 | 122 | async def initialize(self) -> None:
|
94 |
| - """Initialize the service.""" |
| 123 | + """Initialize the service and start health check if this instance is the leader.""" |
95 | 124 | logger.info("Initializing gateway service")
|
96 |
| - self._health_check_task = asyncio.create_task(self._run_health_checks()) |
| 125 | + |
| 126 | + if self._redis_client: |
| 127 | + # Check if Redis is available |
| 128 | + pong = self._redis_client.ping() |
| 129 | + if not pong: |
| 130 | + raise ConnectionError("Redis ping failed.") |
| 131 | + |
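| 132 | + # SET with nx=True and a TTL claims leadership atomically: it succeeds only if no other instance currently holds the key. |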
| 132 | + is_leader = self._redis_client.set(self._leader_key, self._instance_id, ex=self._leader_ttl, nx=True) |
| 133 | + if is_leader: |
| 134 | + logger.info("Acquired Redis leadership. Starting health check task.") |
| 135 | + self._health_check_task = asyncio.create_task(self._run_health_checks()) |
| 136 | + else: |
| 137 | + # No Redis client: every worker starts the task; file-lock leader election (or single-worker mode) is handled inside the loop. |
| 138 | + self._health_check_task = asyncio.create_task(self._run_health_checks()) |
97 | 139 |
|
98 | 140 | async def shutdown(self) -> None:
|
99 | 141 | """Shutdown the service."""
|
@@ -474,6 +516,30 @@ async def forward_request(self, gateway: DbGateway, method: str, params: Optiona
|
474 | 516 | except Exception as e:
|
475 | 517 | raise GatewayConnectionError(f"Failed to forward request to {gateway.name}: {str(e)}")
|
476 | 518 |
|
| 519 | + async def _handle_gateway_failure(self, gateway: DbGateway) -> None: |
| 520 | + """ |
| 521 | + Tracks consecutive health check failures for a gateway. |
| 522 | + Once the failure count reaches the configured threshold, the gateway is deactivated. |
| 523 | +
|
| 524 | + Args: |
| 525 | + gateway (DbGateway): The gateway that failed its health check. |
| 526 | +
|
| 527 | + Returns: |
| 528 | + None |
| 529 | + """ |
| 530 | + if GW_FAILURE_THRESHOLD == -1: |
| 531 | + return # Gateway failure action disabled |
| 532 | + count = self._gateway_failure_counts.get(gateway.id, 0) + 1 |
| 533 | + self._gateway_failure_counts[gateway.id] = count |
| 534 | + |
| 535 | + logger.warning(f"Gateway {gateway.name} failed health check {count} time(s).") |
| 536 | + |
| 537 | + if count >= GW_FAILURE_THRESHOLD: |
| 538 | + logger.error(f"Gateway {gateway.name} failed {GW_FAILURE_THRESHOLD} times. Deactivating...") |
| 539 | + with SessionLocal() as db: |
| 540 | + await self.toggle_gateway_status(db, gateway.id, False) |
| 541 | + self._gateway_failure_counts[gateway.id] = 0 # Reset after deactivation |
| 542 | + |
477 | 543 | async def check_health_of_gateways(self, gateways: List[DbGateway]) -> bool:
|
478 | 544 | """Health check for a list of gateways.
|
479 | 545 |
|
@@ -506,9 +572,8 @@ async def check_health_of_gateways(self, gateways: List[DbGateway]) -> bool:
|
506 | 572 | gateway.last_seen = datetime.utcnow()
|
507 | 573 |
|
508 | 574 | except Exception:
|
509 |
| - with SessionLocal() as db: |
510 |
| - await self.toggle_gateway_status(db=db, gateway_id=gateway.id, activate=False) |
511 |
| - |
| 575 | + await self._handle_gateway_failure(gateway) |
| 576 | + |
512 | 577 | # All gateways passed
|
513 | 578 | return True
|
514 | 579 |
|
@@ -620,19 +685,68 @@ def _get_active_gateways(self) -> list[DbGateway]:
|
620 | 685 | return db.execute(select(DbGateway).where(DbGateway.is_active)).scalars().all()
|
621 | 686 |
|
622 | 687 | async def _run_health_checks(self) -> None:
|
623 |
| - """Run health checks with sync Session in async code.""" |
| 688 | + """Run health checks periodically, |
| 689 | + Uses Redis or FileLock - for multiple workers. |
| 690 | + Uses simple health check for single worker mode.""" |
| 691 | + |
624 | 692 | while True:
|
625 | 693 | try:
|
626 |
| - # Run sync database code in a thread |
627 |
| - gateways = await asyncio.to_thread(self._get_active_gateways) |
| 694 | + if self._redis_client and settings.cache_type == "redis": |
| 695 | + # Redis-based leader check |
| 696 | + current_leader = self._redis_client.get(self._leader_key) |
| 697 | + if current_leader != self._instance_id.encode(): |
| 698 | + return |
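| 699 | + # Still the leader: refresh the TTL so leadership does not lapse before the next cycle. |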
| 699 | + self._redis_client.expire(self._leader_key, self._leader_ttl) |
628 | 700 |
|
629 |
| - if len(gateways) > 0: |
630 |
| - # Async health checks (non-blocking) |
631 |
| - await self.check_health_of_gateways(gateways) |
632 |
| - except Exception as e: |
633 |
| - logger.error(f"Health check run failed: {str(e)}") |
| 701 | + # Run health checks |
| 702 | + gateways = await asyncio.to_thread(self._get_active_gateways) |
| 703 | + if gateways: |
| 704 | + await self.check_health_of_gateways(gateways) |
| 705 | + |
| 706 | + await asyncio.sleep(self._health_check_interval) |
| 707 | + |
| 708 | + elif settings.cache_type == "none": |
| 709 | + try: |
| 710 | + # For single worker mode, run health checks directly |
| 711 | + gateways = await asyncio.to_thread(self._get_active_gateways) |
634 | 712 |
|
635 |
| - await asyncio.sleep(self._health_check_interval) |
| 713 | + if gateways: |
| 714 | + await self.check_health_of_gateways(gateways) |
| 715 | + except Exception as e: |
| 716 | + logger.error(f"Health check run failed: {str(e)}") |
| 717 | + |
| 718 | + await asyncio.sleep(self._health_check_interval) |
| 719 | + |
| 720 | + else: |
| 721 | + # FileLock-based leader fallback |
| 722 | + try: |
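| 723 | + # timeout=0 tries the lock once and raises Timeout immediately if another worker already holds it. |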
| 723 | + self._file_lock.acquire(timeout=0) |
| 724 | + logger.info("File lock acquired. Running health checks.") |
| 725 | + |
| 726 | + while True: |
| 727 | + gateways = await asyncio.to_thread(self._get_active_gateways) |
| 728 | + if gateways: |
| 729 | + await self.check_health_of_gateways(gateways) |
| 730 | + await asyncio.sleep(self._health_check_interval) |
| 731 | + |
| 732 | + except Timeout: |
| 733 | + logger.debug("File lock already held. Retrying later.") |
| 734 | + await asyncio.sleep(self._health_check_interval) |
| 735 | + |
| 736 | + except Exception as e: |
| 737 | + logger.error(f"FileLock health check failed: {str(e)}") |
| 738 | + |
| 739 | + finally: |
| 740 | + if self._file_lock.is_locked: |
| 741 | + try: |
| 742 | + self._file_lock.release() |
| 743 | + logger.info("Released file lock.") |
| 744 | + except Exception as e: |
| 745 | + logger.warning(f"Failed to release file lock: {str(e)}") |
| 746 | + |
| 747 | + except Exception as e: |
| 748 | + logger.error(f"Unexpected error in health check loop: {str(e)}") |
| 749 | + await asyncio.sleep(self._health_check_interval) |
636 | 750 |
|
637 | 751 | def _get_auth_headers(self) -> Dict[str, str]:
|
638 | 752 | """
|
|