|
| 1 | +#!/usr/bin/env python3 |
| 2 | +# -*- coding: utf-8 -*- |
| 3 | +"""db_isready - Wait until the configured database is ready |
| 4 | +========================================================== |
| 5 | +This helper blocks until the given database (defined by an **SQLAlchemy** URL) |
| 6 | +successfully answers a trivial round-trip - ``SELECT 1`` - and then returns. |
| 7 | +It is useful as a container **readiness/health probe** or imported from Python |
| 8 | +code to delay start-up of services that depend on the DB. |
| 9 | +
|
| 10 | +Exit codes when executed as a script |
| 11 | +----------------------------------- |
| 12 | +* ``0`` - database ready. |
| 13 | +* ``1`` - all attempts exhausted / timed-out. |
| 14 | +* ``2`` - :pypi:`SQLAlchemy` is **not** installed. |
| 15 | +* ``3`` - invalid parameter combination (``max_tries``/``interval``/``timeout``). |
| 16 | +
|
| 17 | +Features |
| 18 | +-------- |
| 19 | +* Accepts **any** SQLAlchemy URL supported by the installed version. |
| 20 | +* Timing knobs (tries, interval, connect-timeout) configurable through |
| 21 | + *environment variables* **or** *CLI flags* - see below. |
| 22 | +* Works **synchronously** (blocking) or **asynchronously** - simply |
| 23 | + ``await wait_for_db_ready()``. |
| 24 | +* Credentials appearing in log lines are automatically **redacted**. |
| 25 | +* Depends only on ``sqlalchemy`` (already required by *mcpgateway*). |
| 26 | +
|
| 27 | +Environment variables |
| 28 | +--------------------- |
| 29 | +The script falls back to :pydata:`mcpgateway.config.settings`, but the values |
| 30 | +below can be overridden via environment variables *or* the corresponding |
| 31 | +command-line options. |
| 32 | +
|
+------------------------+------------------------------------------------+------------------------+
| Name                   | Description                                    | Default                |
+========================+================================================+========================+
| ``DATABASE_URL``       | SQLAlchemy connection URL                      | ``sqlite:///./mcp.db`` |
| ``DB_WAIT_MAX_TRIES``  | Maximum attempts before giving up              | ``30``                 |
| ``DB_WAIT_INTERVAL``   | Delay between attempts *(seconds)*             | ``2``                  |
| ``DB_CONNECT_TIMEOUT`` | Per-attempt connect timeout *(seconds)*        | ``2``                  |
| ``LOG_LEVEL``          | Log verbosity when not set via ``--log-level`` | ``INFO``               |
+------------------------+------------------------------------------------+------------------------+
| 42 | +
|
| 43 | +Usage examples |
| 44 | +-------------- |
| 45 | +Shell :: |
| 46 | +
|
| 47 | + python db_isready.py |
| 48 | + python db_isready.py --database-url "postgresql://user:pw@db:5432/mcp" \ |
| 49 | + --max-tries 20 --interval 1 --timeout 1 |
| 50 | +
|
| 51 | +Python :: |
| 52 | +
|
| 53 | + from db_isready import wait_for_db_ready |
| 54 | +
|
| 55 | + await wait_for_db_ready() # asynchronous |
| 56 | + wait_for_db_ready(sync=True) # synchronous / blocking |
| 57 | +""" |
| 58 | +# Future |
| 59 | +from __future__ import annotations |
| 60 | + |
| 61 | +# Standard |
| 62 | +# --------------------------------------------------------------------------- |
| 63 | +# Standard library imports |
| 64 | +# --------------------------------------------------------------------------- |
| 65 | +import argparse |
| 66 | +import asyncio |
| 67 | +import logging |
| 68 | +import os |
| 69 | +import re |
| 70 | +import sys |
| 71 | +import time |
| 72 | +from typing import Any, Dict, Final, Optional |
| 73 | + |
| 74 | +# --------------------------------------------------------------------------- |
| 75 | +# Third-party imports - abort early if SQLAlchemy is missing |
| 76 | +# --------------------------------------------------------------------------- |
| 77 | +try: |
| 78 | + # Third-Party |
| 79 | + from sqlalchemy import create_engine, text |
| 80 | + from sqlalchemy.engine import Engine, URL |
| 81 | + from sqlalchemy.engine.url import make_url |
| 82 | + from sqlalchemy.exc import OperationalError |
| 83 | +except ImportError: # pragma: no cover - handled at runtime for the CLI |
| 84 | + sys.stderr.write("SQLAlchemy not installed - aborting (pip install sqlalchemy)\n") |
| 85 | + sys.exit(2) |
| 86 | + |
| 87 | +# --------------------------------------------------------------------------- |
| 88 | +# Optional project settings (silently ignored if mcpgateway package is absent) |
| 89 | +# --------------------------------------------------------------------------- |
try:
    # First-Party
    from mcpgateway.config import settings
except Exception:  # pragma: no cover - fallback minimal settings

    class _Settings:
        """Stand-in used when the *mcpgateway* settings cannot be imported.

        Supplies the ``database_url`` and ``log_level`` attributes that the
        module-level defaults are derived from.
        """

        log_level: str = "INFO"
        database_url: str = "sqlite:///./mcp.db"

    settings = _Settings()  # type: ignore
| 102 | + |
| 103 | +# --------------------------------------------------------------------------- |
| 104 | +# Environment variable names |
| 105 | +# --------------------------------------------------------------------------- |
# Names of the environment variables consulted for each tunable.
ENV_DB_URL: Final[str] = "DATABASE_URL"
ENV_MAX_TRIES: Final[str] = "DB_WAIT_MAX_TRIES"
ENV_INTERVAL: Final[str] = "DB_WAIT_INTERVAL"
ENV_TIMEOUT: Final[str] = "DB_CONNECT_TIMEOUT"

# ---------------------------------------------------------------------------
# Defaults - overridable via env-vars or CLI flags
# ---------------------------------------------------------------------------
# Resolved once, at import time: the environment variable wins when it is set,
# otherwise the project settings (URL, log level) or the literal fallback apply.
DEFAULT_DB_URL: Final[str] = os.environ.get(ENV_DB_URL, settings.database_url)
DEFAULT_MAX_TRIES: Final[int] = int(os.environ.get(ENV_MAX_TRIES, "30"))
DEFAULT_INTERVAL: Final[float] = float(os.environ.get(ENV_INTERVAL, "2"))
DEFAULT_TIMEOUT: Final[int] = int(os.environ.get(ENV_TIMEOUT, "2"))
DEFAULT_LOG_LEVEL: Final[str] = os.environ.get("LOG_LEVEL", settings.log_level).upper()
| 119 | + |
| 120 | +# --------------------------------------------------------------------------- |
| 121 | +# Helpers - sanitising / formatting util functions |
| 122 | +# --------------------------------------------------------------------------- |
# ``scheme://user:password@`` - group 1 is the user name, group 2 the password.
_CRED_RE: Final[re.Pattern[str]] = re.compile(r"://([^:/?#]+):([^@]+)@")
# ``password=...`` / ``pwd=...`` (case-insensitive) - group 1 is the key.
_PWD_RE: Final[re.Pattern[str]] = re.compile(r"(?i)(password|pwd)=([^\s]+)")


def _sanitize(txt: str) -> str:
    """Hide credentials contained in connection strings or driver errors.

    Args:
        txt: Arbitrary text that may contain a DB DSN or ``password=…``
            parameter.

    Returns:
        Same *txt* but with the password part replaced by ``***``. The user
        name of a DSN and the ``password``/``pwd`` key are kept as-is.
    """

    # ``\g<1>`` is a back-reference to the first capture group. The previous
    # templates used ``\\1`` inside a raw string, which ``re.sub`` reads as an
    # escaped backslash followed by a literal "1", so the user name / key was
    # replaced by the text ``\1`` instead of being preserved.
    redacted = _CRED_RE.sub(r"://\g<1>:***@", txt)
    return _PWD_RE.sub(r"\g<1>=***", redacted)
| 140 | + |
| 141 | + |
def _format_target(url: "URL") -> str:
    """Build a short *host[:port]/db* label for log output.

    Args:
        url: A parsed :class:`sqlalchemy.engine.url.URL` instance.

    Returns:
        For the ``sqlite`` backend, the database path (or ``<memory>`` when the
        URL names no database). For every other backend, the host (defaulting
        to ``localhost``) followed by ``:port`` and ``/database`` when present.
    """

    # SQLite has no host/port - the database path is the whole target.
    if url.get_backend_name() == "sqlite":
        return url.database or "<memory>"

    pieces = [url.host or "localhost"]
    if url.port:
        pieces.append(f":{url.port}")
    if url.database:
        pieces.append(f"/{url.database}")
    return "".join(pieces)
| 159 | + |
| 160 | + |
| 161 | +# --------------------------------------------------------------------------- |
| 162 | +# Public API - *wait_for_db_ready* |
| 163 | +# --------------------------------------------------------------------------- |
| 164 | + |
| 165 | + |
def wait_for_db_ready(
    *,
    database_url: str = DEFAULT_DB_URL,
    max_tries: int = DEFAULT_MAX_TRIES,
    interval: float = DEFAULT_INTERVAL,
    timeout: int = DEFAULT_TIMEOUT,
    logger: Optional[logging.Logger] = None,
    sync: bool = False,
) -> Optional[asyncio.Future[None]]:
    """Wait until the database replies to ``SELECT 1``.

    The probe itself is a blocking retry loop. How it is executed depends on
    the caller:

    * ``sync=True`` - runs in the current thread and returns when finished.
    * ``sync=False`` inside a running event loop - the probe is handed to the
      loop's default executor and the resulting future is returned, so the
      call must be awaited: ``await wait_for_db_ready()``.
    * ``sync=False`` without a running event loop - there is no loop to keep
      responsive, so the probe runs in the current thread.

    Args:
        database_url: SQLAlchemy URL to probe. Falls back to ``$DATABASE_URL``
            or the project default (usually an on-disk SQLite file).
        max_tries: Total number of connection attempts before giving up.
        interval: Delay *in seconds* between attempts.
        timeout: Per-attempt connection timeout in seconds (passed to the DB
            driver for PostgreSQL / MySQL backends).
        logger: Optional custom :class:`logging.Logger`. If omitted, the
            logger named ``"db_isready"`` is used.
        sync: When *True*, run in the **current** thread instead of scheduling
            the probe inside an executor. Setting this flag from inside a
            running event-loop will block that loop!

    Returns:
        ``None`` when the probe already ran to completion, or an awaitable
        :class:`asyncio.Future` (resolving to ``None``) when called with
        ``sync=False`` from inside a running event loop.

    Raises:
        RuntimeError: If *invalid* parameters are supplied or the database is
            still unavailable after the configured number of attempts. When a
            future is returned, the "not ready" error surfaces on ``await``.
    """

    log = logger or logging.getLogger("db_isready")
    if not log.handlers:  # basicConfig does nothing once the root logger has handlers
        logging.basicConfig(
            level=getattr(logging, DEFAULT_LOG_LEVEL, logging.INFO),
            format="%(asctime)s [%(levelname)s] %(message)s",
            datefmt="%Y-%m-%dT%H:%M:%S",
        )

    if max_tries < 1 or interval <= 0 or timeout <= 0:
        raise RuntimeError("Invalid max_tries / interval / timeout values")

    url_obj: URL = make_url(database_url)
    backend: str = url_obj.get_backend_name()
    target: str = _format_target(url_obj)

    log.info(f"Probing {backend} at {target} (timeout={timeout}s, interval={interval}s, max_tries={max_tries})")

    connect_args: Dict[str, Any] = {}
    if backend.startswith(("postgresql", "mysql")):
        # NOTE(review): ``connect_timeout`` is assumed to be understood by the
        # PostgreSQL / MySQL driver in use - confirm for non-default drivers.
        connect_args["connect_timeout"] = timeout

    engine: Engine = create_engine(
        database_url,
        pool_pre_ping=True,
        pool_size=1,
        max_overflow=0,
        connect_args=connect_args,
    )

    def _probe() -> None:  # noqa: D401 - internal helper
        """Run the blocking retry loop, then release the engine.

        Raises:
            RuntimeError: After ``max_tries`` failed attempts.
        """

        start = time.perf_counter()
        try:
            for attempt in range(1, max_tries + 1):
                try:
                    with engine.connect() as conn:
                        conn.execute(text("SELECT 1"))
                except OperationalError as exc:
                    reason = _sanitize(str(exc))
                    if attempt == max_tries:
                        # Nothing left to retry - fail now instead of sleeping
                        # one more interval first.
                        log.debug(f"Attempt {attempt}/{max_tries} failed ({reason}) - giving up")
                        break
                    log.debug(f"Attempt {attempt}/{max_tries} failed ({reason}) - retrying in {interval:.1f}s")
                    time.sleep(interval)
                else:
                    elapsed = time.perf_counter() - start
                    log.info(f"Database ready after {elapsed:.2f}s (attempt {attempt})")
                    return
        finally:
            # The engine exists only for this probe; close its pooled
            # connections whether the probe succeeded or not.
            engine.dispose()

        # Raised outside the ``except`` block on purpose: the driver error is
        # not chained, so its unsanitised text does not end up in tracebacks.
        raise RuntimeError(f"Database not ready after {max_tries} attempts")

    if sync:
        _probe()
        return None

    try:
        running_loop = asyncio.get_running_loop()
    except RuntimeError:
        # No event loop is running in this thread, so blocking here is safe.
        _probe()
        return None

    # Called from a coroutine: off-load to the default executor so the loop
    # stays responsive, and let the caller await the outcome.
    return running_loop.run_in_executor(None, _probe)
| 257 | + |
| 258 | + |
| 259 | +# --------------------------------------------------------------------------- |
| 260 | +# CLI helpers |
| 261 | +# --------------------------------------------------------------------------- |
| 262 | + |
| 263 | + |
def _parse_cli() -> argparse.Namespace:
    """Read the *db_isready* options from the process command line.

    Every option defaults to the matching module-level ``DEFAULT_*`` value, and
    the defaults are shown in ``--help``.

    Returns:
        Parsed :class:`argparse.Namespace` holding all CLI options.
    """

    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="Wait until the configured database is ready.",
    )

    # (flag, add_argument keywords) - registered in the order shown in --help.
    option_table = (
        ("--database-url", {"default": DEFAULT_DB_URL, "help": "SQLAlchemy URL (env DATABASE_URL)"}),
        ("--max-tries", {"type": int, "default": DEFAULT_MAX_TRIES, "help": "Maximum connection attempts"}),
        ("--interval", {"type": float, "default": DEFAULT_INTERVAL, "help": "Delay between attempts in seconds"}),
        ("--timeout", {"type": int, "default": DEFAULT_TIMEOUT, "help": "Per-attempt connect timeout in seconds"}),
        ("--log-level", {"default": DEFAULT_LOG_LEVEL, "help": "Logging level (DEBUG, INFO, …)"}),
    )
    for flag, keywords in option_table:
        cli.add_argument(flag, **keywords)

    return cli.parse_args()
| 285 | + |
| 286 | + |
def main() -> None:  # pragma: no cover
    """CLI entry-point.

    * Parses command-line options.
    * Applies ``--log-level`` to the *db_isready* logger **before** the first
      message is emitted.
    * Rejects invalid ``--max-tries`` / ``--interval`` / ``--timeout`` values
      before probing.
    * Delegates the actual probing to :func:`wait_for_db_ready`.
    * Exits with:

      * ``0`` - database became ready.
      * ``1`` - connection attempts exhausted.
      * ``2`` - SQLAlchemy missing (handled on import).
      * ``3`` - invalid parameter combination.
    """
    cli_args = _parse_cli()

    log = logging.getLogger("db_isready")
    log.setLevel(cli_args.log_level.upper())

    # wait_for_db_ready() reports bad parameters and an exhausted probe with
    # the same exception type, so the parameters are checked here first to
    # give them their own documented exit code.
    if cli_args.max_tries < 1 or cli_args.interval <= 0 or cli_args.timeout <= 0:
        log.error("Invalid max_tries / interval / timeout values")
        sys.exit(3)

    try:
        wait_for_db_ready(
            database_url=cli_args.database_url,
            max_tries=cli_args.max_tries,
            interval=cli_args.interval,
            timeout=cli_args.timeout,
            sync=True,
            logger=log,
        )
    except RuntimeError as exc:
        log.error(f"Database unavailable: {exc}")
        sys.exit(1)

    sys.exit(0)


if __name__ == "__main__":  # pragma: no cover
    main()
0 commit comments