|
| 1 | +# -*- coding: utf-8 -*- |
| 2 | +# Standard |
| 3 | +import asyncio |
| 4 | +import sys |
| 5 | + |
| 6 | +# First-Party |
| 7 | +# Import the module under test after patching where necessary |
| 8 | +import mcpgateway.utils.db_isready as db_isready |
| 9 | + |
| 10 | +# Third-Party |
| 11 | +import pytest |
| 12 | + |
| 13 | +# --------------------------------------------------------------------------- |
| 14 | +# Helper test doubles |
| 15 | +# --------------------------------------------------------------------------- |
| 16 | + |
| 17 | + |
| 18 | +class _DummyConn: |
| 19 | + """A no‑op DBAPI connection that always succeeds on ``execute``.""" |
| 20 | + |
| 21 | + def execute(self, _): |
| 22 | + return 1 # pragma: no cover |
| 23 | + |
| 24 | + # Context‑manager support ------------------------------------------------- |
| 25 | + def __enter__(self): |
| 26 | + return self |
| 27 | + |
| 28 | + def __exit__(self, exc_type, exc, tb): |
| 29 | + return False |
| 30 | + |
| 31 | + |
| 32 | +class _DummyEngine: |
| 33 | + """Mimics the minimal SQLAlchemy *Engine* interface needed by db_isready.""" |
| 34 | + |
| 35 | + def __init__(self, succeed_after: int = 1): |
| 36 | + self._attempts = 0 |
| 37 | + self._succeed_after = max(1, succeed_after) |
| 38 | + |
| 39 | + def connect(self): |
| 40 | + # Import inside the method so SQLAlchemy is only required when tests run |
| 41 | + # Third-Party |
| 42 | + from sqlalchemy.exc import OperationalError # pylint: disable=C0415 |
| 43 | + |
| 44 | + self._attempts += 1 |
| 45 | + if self._attempts < self._succeed_after: |
| 46 | + raise OperationalError("SELECT 1", {}, Exception("boom")) |
| 47 | + return _DummyConn() |
| 48 | + |
| 49 | + # Expose attempts for assertions |
| 50 | + @property |
| 51 | + def attempts(self): # noqa: D401 – simple accessor |
| 52 | + return self._attempts |
| 53 | + |
| 54 | + |
| 55 | +# --------------------------------------------------------------------------- |
| 56 | +# Unit‑tests – utilities first, then public API |
| 57 | +# --------------------------------------------------------------------------- |
| 58 | + |
| 59 | + |
| 60 | +@pytest.mark.parametrize( |
| 61 | + "raw", |
| 62 | + [ |
| 63 | + "postgresql://alice:secret@db/mydb", |
| 64 | + "error password=reallys3cret param=value", |
| 65 | + ], |
| 66 | +) |
| 67 | +def test_sanitize_masks_sensitive_parts(raw): |
| 68 | + """Anything that looks like credentials must be replaced by ***.""" |
| 69 | + |
| 70 | + redacted = db_isready._sanitize(raw) |
| 71 | + |
| 72 | + # The replacement text must contain at least one asterisk block signalling masking |
| 73 | + assert "***" in redacted |
| 74 | + |
| 75 | + # And **no** piece of the original secret text may survive |
| 76 | + assert "secret" not in redacted |
| 77 | + assert "reallys3cret" not in redacted |
| 78 | + |
| 79 | + |
| 80 | +@pytest.mark.parametrize( |
| 81 | + "url, expected", |
| 82 | + [ |
| 83 | + ("sqlite:///:memory:", ":memory:"), # SQLAlchemy represents memory DB with literal string |
| 84 | + ("postgresql://u:[email protected]:5432/mcp", "db.example.com:5432/mcp"), |
| 85 | + ], |
| 86 | +) |
| 87 | +def test_format_target_variants(url, expected): |
| 88 | + """_format_target should create concise human readable targets.""" |
| 89 | + |
| 90 | + assert db_isready._format_target(db_isready.make_url(url)) == expected |
| 91 | + |
| 92 | + |
| 93 | +def test_wait_for_db_ready_success(monkeypatch): |
| 94 | + """A healthy database should succeed on the first attempt.""" |
| 95 | + |
| 96 | + dummy = _DummyEngine(succeed_after=1) |
| 97 | + |
| 98 | + def _fake_create_engine(_url, **kwargs): |
| 99 | + _fake_create_engine.kwargs = kwargs # type: ignore[attr-defined] |
| 100 | + return dummy |
| 101 | + |
| 102 | + monkeypatch.setattr(db_isready, "create_engine", _fake_create_engine) |
| 103 | + monkeypatch.setattr(db_isready.time, "sleep", lambda *_: None) |
| 104 | + |
| 105 | + db_isready.wait_for_db_ready( |
| 106 | + database_url="postgresql://user:pw@localhost:5432/mcp", |
| 107 | + max_tries=3, |
| 108 | + interval=0.001, |
| 109 | + timeout=1, |
| 110 | + sync=True, |
| 111 | + ) |
| 112 | + |
| 113 | + assert dummy.attempts == 1 |
| 114 | + assert _fake_create_engine.kwargs["connect_args"]["connect_timeout"] == 1 # type: ignore[attr-defined] |
| 115 | + |
| 116 | + |
| 117 | +def test_wait_for_db_ready_retries_then_succeeds(monkeypatch): |
| 118 | + """OperationalError should trigger retries until the connection works.""" |
| 119 | + |
| 120 | + dummy = _DummyEngine(succeed_after=3) |
| 121 | + monkeypatch.setattr(db_isready, "create_engine", lambda *_a, **_k: dummy) |
| 122 | + monkeypatch.setattr(db_isready.time, "sleep", lambda *_: None) |
| 123 | + |
| 124 | + db_isready.wait_for_db_ready( |
| 125 | + database_url="postgresql://u:p@db/mcp", |
| 126 | + max_tries=5, |
| 127 | + interval=0.0001, |
| 128 | + timeout=2, |
| 129 | + sync=True, |
| 130 | + ) |
| 131 | + |
| 132 | + assert dummy.attempts == 3 |
| 133 | + |
| 134 | + |
| 135 | +def test_wait_for_db_ready_exhausts_and_raises(monkeypatch): |
| 136 | + """After *max_tries* failures the helper must raise RuntimeError.""" |
| 137 | + |
| 138 | + dummy = _DummyEngine(succeed_after=999) |
| 139 | + monkeypatch.setattr(db_isready, "create_engine", lambda *_a, **_k: dummy) |
| 140 | + monkeypatch.setattr(db_isready.time, "sleep", lambda *_: None) |
| 141 | + |
| 142 | + with pytest.raises(RuntimeError, match="Database not ready after 3 attempts"): |
| 143 | + db_isready.wait_for_db_ready( |
| 144 | + database_url="sqlite:///tmp.db", |
| 145 | + max_tries=3, |
| 146 | + interval=0.001, |
| 147 | + timeout=1, |
| 148 | + sync=True, |
| 149 | + ) |
| 150 | + assert dummy.attempts == 3 |
| 151 | + |
| 152 | + |
| 153 | +def test_wait_for_db_ready_invalid_parameters(): |
| 154 | + """Zero or negative timing parameters are rejected immediately.""" |
| 155 | + |
| 156 | + with pytest.raises(RuntimeError): |
| 157 | + db_isready.wait_for_db_ready(max_tries=0) |
| 158 | + with pytest.raises(RuntimeError): |
| 159 | + db_isready.wait_for_db_ready(interval=0) |
| 160 | + with pytest.raises(RuntimeError): |
| 161 | + db_isready.wait_for_db_ready(timeout=0) |
| 162 | + |
| 163 | + |
| 164 | +def test_wait_for_db_ready_async_path(monkeypatch): |
| 165 | + """Async path should off‑load probe into executor without blocking.""" |
| 166 | + |
| 167 | + dummy = _DummyEngine(succeed_after=1) |
| 168 | + monkeypatch.setattr(db_isready, "create_engine", lambda *_a, **_k: dummy) |
| 169 | + monkeypatch.setattr(db_isready.time, "sleep", lambda *_: None) |
| 170 | + |
| 171 | + # Create a dedicated loop so we can patch run_in_executor cleanly |
| 172 | + loop = asyncio.new_event_loop() |
| 173 | + |
| 174 | + async def _fake_run_in_executor(_executor, func, *args): # noqa: D401 |
| 175 | + # Execute the probe synchronously (no thread) then return dummy future |
| 176 | + func(*args) |
| 177 | + fut = loop.create_future() |
| 178 | + fut.set_result(None) |
| 179 | + return fut |
| 180 | + |
| 181 | + loop.run_in_executor = _fake_run_in_executor # type: ignore[assignment] |
| 182 | + monkeypatch.setattr(asyncio, "get_event_loop", lambda: loop) |
| 183 | + |
| 184 | + db_isready.wait_for_db_ready( |
| 185 | + database_url="postgresql://u:p@db/mcp", |
| 186 | + max_tries=2, |
| 187 | + interval=0.001, |
| 188 | + timeout=1, |
| 189 | + sync=False, |
| 190 | + ) |
| 191 | + |
| 192 | + assert dummy.attempts == 1 |
| 193 | + loop.close() |
| 194 | + |
| 195 | + |
| 196 | +def test_parse_cli_roundtrip(monkeypatch): |
| 197 | + """All CLI flags should be parsed into the expected Namespace values.""" |
| 198 | + |
| 199 | + argv = [ |
| 200 | + "db_isready.py", |
| 201 | + "--database-url", |
| 202 | + "postgresql://u:p@db/mcp", |
| 203 | + "--max-tries", |
| 204 | + "7", |
| 205 | + "--interval", |
| 206 | + "0.5", |
| 207 | + "--timeout", |
| 208 | + "3", |
| 209 | + "--log-level", |
| 210 | + "DEBUG", |
| 211 | + ] |
| 212 | + monkeypatch.setattr(sys, "argv", argv) |
| 213 | + |
| 214 | + ns = db_isready._parse_cli() |
| 215 | + assert ns.database_url == "postgresql://u:p@db/mcp" |
| 216 | + assert ns.max_tries == 7 |
| 217 | + assert ns.interval == 0.5 |
| 218 | + assert ns.timeout == 3 |
| 219 | + assert ns.log_level == "DEBUG" |
0 commit comments