|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import asyncio |
| 4 | +from datetime import datetime, timezone |
| 5 | + |
| 6 | +import pytest |
| 7 | +from fastapi import HTTPException |
| 8 | +from starlette.requests import Request |
| 9 | + |
| 10 | +from nimbus.control_plane import app as control_app |
| 11 | + |
| 12 | + |
| 13 | +def _make_request(headers: dict[str, str], client_ip: str = "203.0.113.10") -> Request: |
| 14 | + scope = { |
| 15 | + "type": "http", |
| 16 | + "method": "GET", |
| 17 | + "path": "/test", |
| 18 | + "scheme": "http", |
| 19 | + "client": (client_ip, 12345), |
| 20 | + "headers": [(key.encode("latin-1"), value.encode("latin-1")) for key, value in headers.items()], |
| 21 | + } |
| 22 | + |
| 23 | + async def receive() -> dict: # pragma: no cover - protocol shim |
| 24 | + return {"type": "http.request", "body": b"", "more_body": False} |
| 25 | + |
| 26 | + return Request(scope, receive) |
| 27 | + |
| 28 | + |
| 29 | +def test_default_cache_scope() -> None: |
| 30 | + assert control_app._default_cache_scope(42) == "pull:org-42,push:org-42" |
| 31 | + |
| 32 | + |
| 33 | +def test_validate_webhook_timestamp_accepts_current(monkeypatch) -> None: |
| 34 | + now = int(1_700_000_000) |
| 35 | + result = control_app._validate_webhook_timestamp(str(now), tolerance_seconds=30, now=now) |
| 36 | + assert result == now |
| 37 | + |
| 38 | + |
| 39 | +@pytest.mark.parametrize("value", ["", "not-int"]) |
| 40 | +def test_validate_webhook_timestamp_rejects_invalid(value: str) -> None: |
| 41 | + with pytest.raises(HTTPException) as exc: |
| 42 | + control_app._validate_webhook_timestamp(value, tolerance_seconds=10) |
| 43 | + assert exc.value.status_code == 400 |
| 44 | + |
| 45 | + |
| 46 | +def test_validate_webhook_timestamp_rejects_stale() -> None: |
| 47 | + now = 1_700_000_000 |
| 48 | + with pytest.raises(HTTPException) as exc: |
| 49 | + control_app._validate_webhook_timestamp(str(now - 100), tolerance_seconds=10, now=now) |
| 50 | + assert exc.value.status_code == 409 |
| 51 | + |
| 52 | + |
| 53 | +def test_get_client_ip_without_trusted_proxies() -> None: |
| 54 | + request = _make_request({}, client_ip="198.51.100.7") |
| 55 | + ip = control_app.get_client_ip(request, trusted_proxies=[]) |
| 56 | + assert ip == "198.51.100.7" |
| 57 | + |
| 58 | + |
| 59 | +def test_get_client_ip_with_trusted_proxy() -> None: |
| 60 | + headers = {"x-forwarded-for": "10.0.0.5"} |
| 61 | + request = _make_request(headers, client_ip="192.0.2.1") |
| 62 | + ip = control_app.get_client_ip(request, trusted_proxies=["192.0.2.0/24"]) |
| 63 | + assert ip == "10.0.0.5" |
| 64 | + |
| 65 | + |
| 66 | +def test_get_client_ip_with_untrusted_proxy() -> None: |
| 67 | + headers = {"x-forwarded-for": "10.0.0.5"} |
| 68 | + request = _make_request(headers, client_ip="203.0.113.1") |
| 69 | + ip = control_app.get_client_ip(request, trusted_proxies=["192.0.2.0/24"]) |
| 70 | + assert ip == "203.0.113.1" |
| 71 | + |
| 72 | + |
| 73 | +def test_row_to_ssh_session_parses_strings() -> None: |
| 74 | + row = { |
| 75 | + "session_id": "sess", |
| 76 | + "job_id": 1, |
| 77 | + "agent_id": "agent", |
| 78 | + "host_port": 2222, |
| 79 | + "created_at": datetime(2024, 1, 1, tzinfo=timezone.utc).isoformat(), |
| 80 | + "expires_at": datetime(2024, 1, 1, 1, tzinfo=timezone.utc).isoformat(), |
| 81 | + } |
| 82 | + session = control_app._row_to_ssh_session(row) |
| 83 | + assert session.session_id == "sess" |
| 84 | + assert session.created_at.tzinfo is not None |
| 85 | + assert session.expires_at > session.created_at |
| 86 | + |
| 87 | + |
| 88 | +def test_rate_limiter_allows_within_limit(monkeypatch) -> None: |
| 89 | + limiter = control_app.RateLimiter(limit=2, interval=1.0) |
| 90 | + times = [0.0, 0.1, 0.2] |
| 91 | + |
| 92 | + def fake_time() -> float: |
| 93 | + return times.pop(0) |
| 94 | + |
| 95 | + monkeypatch.setattr(control_app.time, "time", fake_time) |
| 96 | + assert limiter.allow("key") is True |
| 97 | + assert limiter.allow("key") is True |
| 98 | + assert limiter.allow("key") is False |
| 99 | + |
| 100 | + |
| 101 | +def test_rate_limiter_disabled() -> None: |
| 102 | + limiter = control_app.RateLimiter(limit=0, interval=1.0) |
| 103 | + for _ in range(5): |
| 104 | + assert limiter.allow("key") is True |
0 commit comments