diff --git a/src/rots/commands/proxy/__init__.py b/src/rots/commands/proxy/__init__.py index 8973324..580266e 100644 --- a/src/rots/commands/proxy/__init__.py +++ b/src/rots/commands/proxy/__init__.py @@ -2,11 +2,12 @@ """Proxy management commands for OTS containers.""" -from .app import app, diff, reload, render, trace +from .app import app, diff, probe, reload, render, trace __all__ = [ "app", "diff", + "probe", "reload", "render", "trace", diff --git a/src/rots/commands/proxy/_helpers.py b/src/rots/commands/proxy/_helpers.py index dbb13c6..5331239 100644 --- a/src/rots/commands/proxy/_helpers.py +++ b/src/rots/commands/proxy/_helpers.py @@ -14,6 +14,7 @@ import contextlib import copy +import dataclasses import json import socket import subprocess @@ -21,6 +22,7 @@ import threading import time import urllib.parse +from datetime import UTC, datetime from http.server import BaseHTTPRequestHandler, HTTPServer from pathlib import Path from typing import TYPE_CHECKING @@ -37,6 +39,27 @@ class ProxyError(Exception): """Error during proxy configuration.""" +@dataclasses.dataclass +class ProbeResult: + """Result of probing a URL with curl.""" + + url: str + http_code: int + ssl_verify_result: int # 0 = valid chain + ssl_verify_ok: bool + cert_issuer: str + cert_subject: str + cert_expiry: str + http_version: str + time_namelookup: float # seconds + time_connect: float + time_appconnect: float # TLS handshake complete + time_starttransfer: float # TTFB + time_total: float + response_headers: dict[str, list[str]] + curl_json: dict # raw write-out for --json passthrough + + def parse_trace_url(url: str) -> urllib.parse.ParseResult: """Normalise and validate a URL for ``proxy trace``. @@ -493,3 +516,281 @@ def _read_stderr() -> str: except subprocess.TimeoutExpired: proc.kill() proc.wait(timeout=3) + + +# --------------------------------------------------------------------------- +# Probe helpers +# --------------------------------------------------------------------------- + +_CURL_SENTINEL = "%%CURL_JSON%%" + + +def build_curl_args( + url: str, + *, + resolve: str | None = None, + connect_to: str | None = None, + cacert: Path | None = None, + cert_status: bool = False, + extra_headers: tuple[str, ...] = (), + timeout: int = 30, + method: str | None = None, + insecure: bool = False, + follow_redirects: bool = False, +) -> list[str]: + """Build the curl command list for probing *url*. + + Returns the full argv list without executing anything — purely + testable by asserting on the returned list. + """ + cmd = [ + "curl", + "-sS", + "-o", + "/dev/null", + "-D", + "-", + "-w", + f"\n{_CURL_SENTINEL}\n%{{json}}", + "--max-time", + str(timeout), + ] + + if resolve is not None: + cmd.extend(["--resolve", resolve]) + if connect_to is not None: + cmd.extend(["--connect-to", connect_to]) + if cacert is not None: + cmd.extend(["--cacert", str(cacert)]) + if cert_status: + cmd.append("--cert-status") + for h in extra_headers: + cmd.extend(["-H", h]) + if method is not None: + cmd.extend(["-X", method]) + if insecure: + cmd.append("-k") + if follow_redirects: + cmd.append("-L") + + cmd.append("--") + cmd.append(url) + return cmd + + +def parse_curl_output(stdout: str) -> ProbeResult: + """Parse combined curl output (``-D -`` headers + ``-w '%{json}'``). + + The output is split on the ``%%CURL_JSON%%`` sentinel. The first + part contains HTTP response headers; the second is the JSON blob + from curl's ``--write-out '%{json}'``. + + Raises: + ProxyError: When the sentinel is missing or JSON is malformed. + """ + if _CURL_SENTINEL not in stdout: + raise ProxyError("curl output missing sentinel — unexpected format") + + header_section, json_section = stdout.split(_CURL_SENTINEL, 1) + + # When curl follows redirects (-L), -D - emits multiple header blocks + # (one per hop) separated by blank lines. Only parse the final block + # so assertions and output reflect the terminal response. + normalized = header_section.replace("\r\n", "\n") + header_blocks = [b for b in normalized.split("\n\n") if b.strip()] + final_block = header_blocks[-1] if header_blocks else "" + + response_headers: dict[str, list[str]] = {} + for line in final_block.splitlines(): + if ":" in line and not line.startswith("HTTP/"): + key, _, value = line.partition(":") + key = key.strip() + response_headers.setdefault(key, []).append(value.strip()) + + try: + curl_json = json.loads(json_section.strip()) + except (json.JSONDecodeError, ValueError) as e: + raise ProxyError(f"curl JSON output malformed: {e}") from e + + # Extract cert details from the certs string + certs_str = curl_json.get("certs", "") + cert_issuer = "" + cert_subject = "" + cert_expiry = "" + for cert_line in certs_str.splitlines(): + stripped = cert_line.strip() + if stripped.startswith("Issuer:") and not cert_issuer: + cert_issuer = stripped[len("Issuer:") :].strip() + elif stripped.startswith("Subject:") and not cert_subject: + cert_subject = stripped[len("Subject:") :].strip() + elif stripped.startswith("Expire date:") and not cert_expiry: + cert_expiry = stripped[len("Expire date:") :].strip() + + ssl_verify = curl_json.get("ssl_verify_result", -1) + return ProbeResult( + url=curl_json.get("url_effective", curl_json.get("url", "")), + http_code=curl_json.get("http_code", 0), + ssl_verify_result=ssl_verify, + ssl_verify_ok=ssl_verify == 0, + cert_issuer=cert_issuer, + cert_subject=cert_subject, + cert_expiry=cert_expiry, + http_version=curl_json.get("http_version", ""), + time_namelookup=curl_json.get("time_namelookup", 0.0), + time_connect=curl_json.get("time_connect", 0.0), + time_appconnect=curl_json.get("time_appconnect", 0.0), + time_starttransfer=curl_json.get("time_starttransfer", 0.0), + time_total=curl_json.get("time_total", 0.0), + response_headers=response_headers, + curl_json=curl_json, + ) + + +def _parse_cert_expiry_days(cert_expiry: str) -> int | None: + """Parse cert expiry string and return days remaining. + + The format comes from curl's ``%{json}`` output, e.g., + ``"Aug 17 23:59:59 2026 GMT"``. + + Returns None on empty string or parse failure. + """ + if not cert_expiry: + return None + try: + expiry = datetime.strptime(cert_expiry, "%b %d %H:%M:%S %Y %Z") + expiry = expiry.replace(tzinfo=UTC) + now = datetime.now(UTC) + return (expiry - now).days + except (ValueError, OverflowError): + return None + + +def evaluate_assertions( + result: ProbeResult, + *, + expect_status: int | None = None, + expect_headers: tuple[str, ...] = (), + expect_cert_days: int | None = None, +) -> list[dict]: + """Evaluate assertions against a probe result. + + Returns a list of ``{"check": str, "passed": bool, "expected": str, + "actual": str}`` dicts. Returns an empty list when no assertions + are specified. + """ + checks: list[dict] = [] + + if expect_status is not None: + checks.append( + { + "check": "status", + "passed": result.http_code == expect_status, + "expected": str(expect_status), + "actual": str(result.http_code), + } + ) + + # Build a case-insensitive lookup of response headers + lower_headers = {k.lower(): (k, vs) for k, vs in result.response_headers.items()} + + for header_spec in expect_headers: + key, _, expected_value = header_spec.partition(":") + key = key.strip() + expected_value = expected_value.strip() + + orig_key, actual_values = lower_headers.get(key.lower(), (key, [])) + checks.append( + { + "check": f"header {key}", + "passed": expected_value in actual_values, + "expected": f"{key}: {expected_value}", + "actual": ( + f"{orig_key}: {', '.join(actual_values)}" if actual_values else "(missing)" + ), + } + ) + + if expect_cert_days is not None: + days = _parse_cert_expiry_days(result.cert_expiry) + if days is None: + checks.append( + { + "check": "cert-expiry", + "passed": False, + "expected": f">= {expect_cert_days} days", + "actual": "(no expiry date available)", + } + ) + else: + checks.append( + { + "check": "cert-expiry", + "passed": days >= expect_cert_days, + "expected": f">= {expect_cert_days} days", + "actual": f"{days} days", + } + ) + + return checks + + +def run_probe( + url: str, + *, + resolve: str | None = None, + connect_to: str | None = None, + cacert: Path | None = None, + cert_status: bool = False, + extra_headers: tuple[str, ...] = (), + timeout: int = 30, + method: str | None = None, + insecure: bool = False, + follow_redirects: bool = False, + executor: Executor | None = None, +) -> ProbeResult: + """Execute curl and return parsed probe results. + + Uses *executor* when provided (remote execution via SSH), otherwise + runs curl locally via subprocess. + + Raises: + ProxyError: On curl errors (not found, timeout, non-zero exit). + """ + cmd = build_curl_args( + url, + resolve=resolve, + connect_to=connect_to, + cacert=cacert, + cert_status=cert_status, + extra_headers=extra_headers, + timeout=timeout, + method=method, + insecure=insecure, + follow_redirects=follow_redirects, + ) + + # Give subprocess a bit more than curl's --max-time to avoid racing + subprocess_timeout = timeout + 5 + + if _is_remote(executor): + result = executor.run(cmd, timeout=subprocess_timeout) # type: ignore[union-attr] + if not result.ok: + raise ProxyError(f"curl failed (exit {result.returncode}): {result.stderr}") + return parse_curl_output(result.stdout) + + try: + proc = subprocess.run( + cmd, + capture_output=True, + text=True, + timeout=subprocess_timeout, + ) + except FileNotFoundError as e: + raise ProxyError("curl not found in PATH") from e + except subprocess.TimeoutExpired as e: + raise ProxyError("curl timed out") from e + + if proc.returncode != 0: + raise ProxyError(f"curl failed (exit {proc.returncode}): {proc.stderr}") + + return parse_curl_output(proc.stdout) diff --git a/src/rots/commands/proxy/app.py b/src/rots/commands/proxy/app.py index 6747c62..64abf3d 100644 --- a/src/rots/commands/proxy/app.py +++ b/src/rots/commands/proxy/app.py @@ -10,7 +10,9 @@ """ import contextlib +import json import logging +import time from pathlib import Path from typing import Annotated @@ -19,10 +21,12 @@ from rots import context from rots.config import Config -from ..common import DryRun +from ..common import DryRun, JsonOutput from ._helpers import ( + ProbeResult, ProxyError, adapt_to_json, + evaluate_assertions, find_free_port, parse_trace_url, patch_caddy_json, @@ -30,6 +34,7 @@ render_template, run_caddy, run_echo_server, + run_probe, validate_caddy_config, ) @@ -532,3 +537,207 @@ def trace( except ProxyError as e: raise SystemExit(str(e)) from e + + +@app.command +def probe( + url: Annotated[str, cyclopts.Parameter(help="URL to probe")], + resolve: Annotated[ + str | None, + cyclopts.Parameter(name="--resolve", help="DNS override: host:port:addr"), + ] = None, + connect_to: Annotated[ + str | None, + cyclopts.Parameter(name="--connect-to", help="Reroute: host:port:host2:port2"), + ] = None, + cacert: Annotated[ + Path | None, + cyclopts.Parameter(name="--cacert", help="CA cert for verification"), + ] = None, + cert_status: Annotated[ + bool, + cyclopts.Parameter(name="--cert-status", help="Check OCSP stapling"), + ] = False, + method: Annotated[ + str | None, + cyclopts.Parameter(name="--method", help="HTTP method (e.g., HEAD, OPTIONS)"), + ] = None, + insecure: Annotated[ + bool, + cyclopts.Parameter(name="--insecure", help="Skip TLS verification (-k)"), + ] = False, + follow_redirects: Annotated[ + bool, + cyclopts.Parameter(name="--follow", help="Follow redirects (-L)"), + ] = False, + header: Annotated[ + tuple[str, ...], + cyclopts.Parameter(name="--header", help="Extra header (repeatable)"), + ] = (), + expect_status: Annotated[ + int | None, + cyclopts.Parameter(name="--expect-status", help="Assert HTTP status"), + ] = None, + expect_header: Annotated[ + tuple[str, ...], + cyclopts.Parameter(name="--expect-header", help="Assert header (repeatable)"), + ] = (), + expect_cert_days: Annotated[ + int | None, + cyclopts.Parameter( + name="--expect-cert-days-remaining", help="Assert minimum days until cert expiry" + ), + ] = None, + json_output: JsonOutput = False, + retries: Annotated[ + int, + cyclopts.Parameter(name="--retries", help="Number of retry attempts (0 = no retries)"), + ] = 0, + retry_delay: Annotated[ + float, + cyclopts.Parameter(name="--retry-delay", help="Seconds between retries"), + ] = 1.0, +) -> None: + """Probe a live URL with curl and report TLS, headers, and timing. + + Verifies deployed behaviour — does the live endpoint return the right + TLS cert, security headers, and status code? + + Supports DNS-independent staging tests via --resolve and --connect-to. + When assertions (--expect-status, --expect-header) are provided, exits + non-zero on failure for CI use. Use --retries to wait for a service to + become ready (retries both connection errors and assertion failures). + + Examples: + rots proxy probe https://us.onetime.co/api/v2/status + rots proxy probe https://us.onetime.co/api/v2/status --expect-status 200 + rots proxy probe https://us.onetime.co/api/v2/status \\ + --resolve us.onetime.co:443:10.0.0.5 + rots proxy probe https://us.onetime.co/api/v2/status \\ + --expect-status 200 --retries 5 --retry-delay 2.0 + rots --host eu-web-01 proxy probe https://localhost:7043/health + """ + cfg = Config() + ex = cfg.get_executor(host=context.host_var.get(None)) + + try: + parse_trace_url(url) + except ProxyError as e: + raise SystemExit(str(e)) from e + + last_result: ProbeResult | None = None + last_assertions: list[dict] = [] + + for attempt in range(retries + 1): + try: + last_result = run_probe( + url, + resolve=resolve, + connect_to=connect_to, + cacert=cacert, + cert_status=cert_status, + extra_headers=header, + method=method, + insecure=insecure, + follow_redirects=follow_redirects, + executor=ex, + ) + except ProxyError as e: + if attempt < retries: + time.sleep(retry_delay) + continue + raise SystemExit(str(e)) from e + + last_assertions = evaluate_assertions( + last_result, + expect_status=expect_status, + expect_headers=expect_header, + expect_cert_days=expect_cert_days, + ) + + all_passed = not last_assertions or all(a["passed"] for a in last_assertions) + if all_passed or attempt == retries: + break + + time.sleep(retry_delay) + + assert last_result is not None # loop always sets or raises + + if json_output: + _print_probe_json(last_result, last_assertions) + else: + _print_probe_human(last_result, last_assertions) + + # Exit code: 0 if no assertions or all pass, 1 if any fail + if last_assertions and not all(a["passed"] for a in last_assertions): + raise SystemExit(1) + + +def _print_probe_human(result: ProbeResult, assertions: list[dict]) -> None: + """Print human-readable probe output.""" + print(result.url) + + # TLS section + if result.url.startswith("https"): + tag = "[ok]" if result.ssl_verify_ok else "[FAIL]" + label = ( + "verified" if result.ssl_verify_ok else (f"verify failed ({result.ssl_verify_result})") + ) + print(f"\n tls: {tag} {label}") + if result.cert_issuer: + print(f" issuer: {result.cert_issuer}") + if result.cert_expiry: + print(f" expiry: {result.cert_expiry}") + + # Status + print(f"\n status: {result.http_code}") + + # Headers + if result.response_headers: + print("\n headers:") + for k, vs in sorted(result.response_headers.items()): + for v in vs: + print(f" {k}: {v}") + + # Timing + print("\n timing:") + print(f" dns: {result.time_namelookup * 1000:7.1f} ms") + print(f" connect: {result.time_connect * 1000:7.1f} ms") + print(f" tls: {result.time_appconnect * 1000:7.1f} ms") + print(f" ttfb: {result.time_starttransfer * 1000:7.1f} ms") + print(f" total: {result.time_total * 1000:7.1f} ms") + + # Assertions + if assertions: + print() + for a in assertions: + tag = "[ok]" if a["passed"] else "[FAIL]" + if a["passed"]: + print(f" {tag} {a['check']} {a['expected']}") + else: + print(f" {tag} {a['check']}: expected {a['expected']}, got {a['actual']}") + + +def _print_probe_json(result: ProbeResult, assertions: list[dict]) -> None: + """Print JSON probe output.""" + output = { + "url": result.url, + "http_code": result.http_code, + "tls": { + "verified": result.ssl_verify_ok, + "verify_result": result.ssl_verify_result, + "issuer": result.cert_issuer, + "subject": result.cert_subject, + "expiry": result.cert_expiry, + }, + "timing": { + "dns_ms": round(result.time_namelookup * 1000, 1), + "connect_ms": round(result.time_connect * 1000, 1), + "tls_ms": round(result.time_appconnect * 1000, 1), + "ttfb_ms": round(result.time_starttransfer * 1000, 1), + "total_ms": round(result.time_total * 1000, 1), + }, + "headers": result.response_headers, + "assertions": assertions, + } + print(json.dumps(output, indent=2)) diff --git a/tests/commands/proxy/test_app.py b/tests/commands/proxy/test_app.py index f399ed9..9706251 100644 --- a/tests/commands/proxy/test_app.py +++ b/tests/commands/proxy/test_app.py @@ -897,3 +897,310 @@ def _ctx(): assert "response: 200" in captured.out assert "body: OK" in captured.out assert "forwarded request:" not in captured.out + + +class TestProbeCommand: + """Test probe command.""" + + def _make_probe_result(self, **kwargs): + """Build a ProbeResult with sensible defaults.""" + from rots.commands.proxy._helpers import ProbeResult + + defaults = { + "url": "https://example.com/api/v2/status", + "http_code": 200, + "ssl_verify_result": 0, + "ssl_verify_ok": True, + "cert_issuer": "R11", + "cert_subject": "CN=example.com", + "cert_expiry": "Aug 17 23:59:59 2026 GMT", + "http_version": "2", + "time_namelookup": 0.005, + "time_connect": 0.020, + "time_appconnect": 0.080, + "time_starttransfer": 0.150, + "time_total": 0.180, + "response_headers": { + "X-Frame-Options": ["DENY"], + "O-Via": ["B76s2"], + "Strict-Transport-Security": ["max-age=63072000"], + }, + "curl_json": {}, + } + defaults.update(kwargs) + return ProbeResult(**defaults) + + def test_human_output(self, mocker, capsys): + """Should print human-readable probe output.""" + from rots.commands.proxy.app import probe + + mocker.patch( + "rots.commands.proxy.app.run_probe", + return_value=self._make_probe_result(), + ) + + probe(url="https://example.com/api/v2/status") + + captured = capsys.readouterr() + assert "https://example.com/api/v2/status" in captured.out + assert "[ok] verified" in captured.out + assert "issuer: R11" in captured.out + assert "status: 200" in captured.out + assert "dns:" in captured.out + assert "X-Frame-Options: DENY" in captured.out + + def test_json_output(self, mocker, capsys): + """Should print JSON probe output.""" + from rots.commands.proxy.app import probe + + mocker.patch( + "rots.commands.proxy.app.run_probe", + return_value=self._make_probe_result(), + ) + + probe(url="https://example.com/api/v2/status", json_output=True) + + captured = capsys.readouterr() + output = json.loads(captured.out) + assert output["http_code"] == 200 + assert output["tls"]["verified"] is True + assert output["tls"]["issuer"] == "R11" + assert "dns_ms" in output["timing"] + assert output["headers"]["X-Frame-Options"] == ["DENY"] + + def test_assertion_pass_no_exit(self, mocker, capsys): + """Should not raise SystemExit when assertions pass.""" + from rots.commands.proxy.app import probe + + mocker.patch( + "rots.commands.proxy.app.run_probe", + return_value=self._make_probe_result(), + ) + + # Should not raise + probe( + url="https://example.com/api/v2/status", + expect_status=200, + expect_header=("O-Via: B76s2",), + ) + + captured = capsys.readouterr() + assert "[ok] status" in captured.out + assert "[ok] header O-Via" in captured.out + + def test_assertion_fail_exits_1(self, mocker): + """Should raise SystemExit(1) when assertions fail.""" + from rots.commands.proxy.app import probe + + mocker.patch( + "rots.commands.proxy.app.run_probe", + return_value=self._make_probe_result(http_code=404), + ) + + with pytest.raises(SystemExit) as exc_info: + probe(url="https://example.com/api/v2/status", expect_status=200) + + assert exc_info.value.code == 1 + + def test_proxy_error_exits(self, mocker): + """Should exit with error message on ProxyError.""" + from rots.commands.proxy._helpers import ProxyError + from rots.commands.proxy.app import probe + + mocker.patch( + "rots.commands.proxy.app.run_probe", + side_effect=ProxyError("curl not found in PATH"), + ) + + with pytest.raises(SystemExit) as exc_info: + probe(url="https://example.com/api/v2/status") + + assert "curl not found" in str(exc_info.value) + + def test_resolve_passthrough(self, mocker): + """Should pass --resolve to run_probe.""" + from rots.commands.proxy.app import probe + + mock_run = mocker.patch( + "rots.commands.proxy.app.run_probe", + return_value=self._make_probe_result(), + ) + + probe(url="https://example.com/", resolve="example.com:443:10.0.0.5") + + mock_run.assert_called_once() + call_kwargs = mock_run.call_args + assert call_kwargs[1]["resolve"] == "example.com:443:10.0.0.5" + + def test_expect_header_evaluation(self, mocker, capsys): + """Should evaluate header assertions and show results.""" + from rots.commands.proxy.app import probe + + mocker.patch( + "rots.commands.proxy.app.run_probe", + return_value=self._make_probe_result(), + ) + + # X-Frame-Options: DENY should pass, X-Missing: value should fail + with pytest.raises(SystemExit) as exc_info: + probe( + url="https://example.com/api/v2/status", + expect_header=("X-Frame-Options: DENY", "X-Missing: value"), + ) + + assert exc_info.value.code == 1 + captured = capsys.readouterr() + assert "[ok] header X-Frame-Options" in captured.out + assert "[FAIL] header X-Missing" in captured.out + + def test_method_passthrough(self, mocker): + """Should pass --method to run_probe.""" + from rots.commands.proxy.app import probe + + mock_run = mocker.patch( + "rots.commands.proxy.app.run_probe", + return_value=self._make_probe_result(), + ) + + probe(url="https://example.com/", method="HEAD") + + mock_run.assert_called_once() + assert mock_run.call_args[1]["method"] == "HEAD" + + def test_insecure_passthrough(self, mocker): + """Should pass --insecure to run_probe.""" + from rots.commands.proxy.app import probe + + mock_run = mocker.patch( + "rots.commands.proxy.app.run_probe", + return_value=self._make_probe_result(), + ) + + probe(url="https://example.com/", insecure=True) + + mock_run.assert_called_once() + assert mock_run.call_args[1]["insecure"] is True + + def test_follow_passthrough(self, mocker): + """Should pass --follow to run_probe.""" + from rots.commands.proxy.app import probe + + mock_run = mocker.patch( + "rots.commands.proxy.app.run_probe", + return_value=self._make_probe_result(), + ) + + probe(url="https://example.com/", follow_redirects=True) + + mock_run.assert_called_once() + assert mock_run.call_args[1]["follow_redirects"] is True + + def test_retry_succeeds_after_probe_failure(self, mocker, capsys): + """Should succeed when run_probe fails first then succeeds on retry.""" + from rots.commands.proxy._helpers import ProxyError + from rots.commands.proxy.app import probe + + mock_run = mocker.patch( + "rots.commands.proxy.app.run_probe", + side_effect=[ + ProxyError("connection refused"), + self._make_probe_result(), + ], + ) + mocker.patch("rots.commands.proxy.app.time.sleep") + + # Should not raise + probe(url="https://example.com/api/v2/status", retries=1) + + assert mock_run.call_count == 2 + captured = capsys.readouterr() + assert "status: 200" in captured.out + + def test_retry_exhausted_exits(self, mocker): + """Should raise SystemExit when all retry attempts fail.""" + from rots.commands.proxy._helpers import ProxyError + from rots.commands.proxy.app import probe + + mocker.patch( + "rots.commands.proxy.app.run_probe", + side_effect=ProxyError("connection refused"), + ) + mocker.patch("rots.commands.proxy.app.time.sleep") + + with pytest.raises(SystemExit) as exc_info: + probe(url="https://example.com/api/v2/status", retries=2) + + assert "connection refused" in str(exc_info.value) + + def test_retry_assertion_failure_then_pass(self, mocker, capsys): + """Should retry when assertions fail and succeed on next attempt.""" + from rots.commands.proxy.app import probe + + mock_run = mocker.patch( + "rots.commands.proxy.app.run_probe", + side_effect=[ + self._make_probe_result(http_code=503), + self._make_probe_result(http_code=200), + ], + ) + mocker.patch("rots.commands.proxy.app.time.sleep") + + # Should not raise — second attempt returns 200 + probe( + url="https://example.com/api/v2/status", + expect_status=200, + retries=1, + ) + + assert mock_run.call_count == 2 + captured = capsys.readouterr() + assert "[ok] status" in captured.out + + def test_retry_delay_called(self, mocker): + """Should sleep with the configured delay between retries.""" + from rots.commands.proxy._helpers import ProxyError + from rots.commands.proxy.app import probe + + mocker.patch( + "rots.commands.proxy.app.run_probe", + side_effect=[ + ProxyError("timeout"), + self._make_probe_result(), + ], + ) + mock_sleep = mocker.patch("rots.commands.proxy.app.time.sleep") + + probe(url="https://example.com/", retries=1, retry_delay=2.5) + + mock_sleep.assert_called_once_with(2.5) + + def test_no_retry_default(self, mocker, capsys): + """With retries=0 (default), run_probe should be called exactly once.""" + from rots.commands.proxy.app import probe + + mock_run = mocker.patch( + "rots.commands.proxy.app.run_probe", + return_value=self._make_probe_result(), + ) + + probe(url="https://example.com/api/v2/status") + + mock_run.assert_called_once() + + def test_cert_days_passthrough(self, mocker, capsys): + """Should pass expect_cert_days to evaluate_assertions.""" + from rots.commands.proxy.app import probe + + mocker.patch( + "rots.commands.proxy.app.run_probe", + return_value=self._make_probe_result(), + ) + mock_eval = mocker.patch( + "rots.commands.proxy.app.evaluate_assertions", + return_value=[], + ) + + probe(url="https://example.com/api/v2/status", expect_cert_days=30) + + mock_eval.assert_called_once() + assert mock_eval.call_args[1]["expect_cert_days"] == 30 diff --git a/tests/commands/proxy/test_helpers.py b/tests/commands/proxy/test_helpers.py index 2460acd..2769b2f 100644 --- a/tests/commands/proxy/test_helpers.py +++ b/tests/commands/proxy/test_helpers.py @@ -5,6 +5,7 @@ import socket import subprocess import urllib.request +from datetime import UTC import pytest @@ -834,3 +835,415 @@ def test_shuts_down_after_context_exit(self): # After exiting, port should be unbound (server shut down) with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.bind(("127.0.0.1", port)) # should not raise + + +class TestBuildCurlArgs: + """Test build_curl_args function.""" + + def test_minimal_url(self): + """Should produce base curl command with just the URL.""" + from rots.commands.proxy._helpers import build_curl_args + + args = build_curl_args("https://example.com") + assert args[0] == "curl" + assert "-sS" in args + assert "-D" in args + assert args[-2] == "--" + assert args[-1] == "https://example.com" + assert "--max-time" in args + idx = args.index("--max-time") + assert args[idx + 1] == "30" + + def test_resolve_flag(self): + """Should add --resolve when specified.""" + from rots.commands.proxy._helpers import build_curl_args + + args = build_curl_args("https://example.com", resolve="example.com:443:10.0.0.5") + assert "--resolve" in args + idx = args.index("--resolve") + assert args[idx + 1] == "example.com:443:10.0.0.5" + + def test_connect_to_flag(self): + """Should add --connect-to when specified.""" + from rots.commands.proxy._helpers import build_curl_args + + args = build_curl_args("https://example.com", connect_to="example.com:443:staging:443") + assert "--connect-to" in args + idx = args.index("--connect-to") + assert args[idx + 1] == "example.com:443:staging:443" + + def test_cacert_flag(self): + """Should add --cacert with path.""" + from pathlib import Path + + from rots.commands.proxy._helpers import build_curl_args + + args = build_curl_args("https://example.com", cacert=Path("/tmp/ca.pem")) + assert "--cacert" in args + idx = args.index("--cacert") + assert args[idx + 1] == "/tmp/ca.pem" + + def test_cert_status_flag(self): + """Should add --cert-status when True.""" + from rots.commands.proxy._helpers import build_curl_args + + args = build_curl_args("https://example.com", cert_status=True) + assert "--cert-status" in args + + def test_extra_headers(self): + """Should add -H for each extra header.""" + from rots.commands.proxy._helpers import build_curl_args + + args = build_curl_args( + "https://example.com", + extra_headers=("Origin: https://example.com", "X-Custom: test"), + ) + h_indices = [i for i, a in enumerate(args) if a == "-H"] + assert len(h_indices) == 2 + assert args[h_indices[0] + 1] == "Origin: https://example.com" + assert args[h_indices[1] + 1] == "X-Custom: test" + + def test_method_flag(self): + """Should add -X when method is specified.""" + from rots.commands.proxy._helpers import build_curl_args + + args = build_curl_args("https://example.com", method="HEAD") + assert "-X" in args + idx = args.index("-X") + assert args[idx + 1] == "HEAD" + + def test_method_default_omits_x(self): + """Should not include -X when method is None.""" + from rots.commands.proxy._helpers import build_curl_args + + args = build_curl_args("https://example.com") + assert "-X" not in args + + def test_insecure_flag(self): + """Should add -k when insecure is True.""" + from rots.commands.proxy._helpers import build_curl_args + + args = build_curl_args("https://example.com", insecure=True) + assert "-k" in args + + def test_insecure_false_omits_k(self): + """Should not include -k when insecure is False.""" + from rots.commands.proxy._helpers import build_curl_args + + args = build_curl_args("https://example.com") + assert "-k" not in args + + def test_follow_flag(self): + """Should add -L when follow_redirects is True.""" + from rots.commands.proxy._helpers import build_curl_args + + args = build_curl_args("https://example.com", follow_redirects=True) + assert "-L" in args + + def test_follow_default_no_L(self): + """Should not include -L when follow_redirects is False.""" + from rots.commands.proxy._helpers import build_curl_args + + args = build_curl_args("https://example.com") + assert "-L" not in args + + +class TestParseCurlOutput: + """Test parse_curl_output function.""" + + def _make_output(self, headers: str = "", curl_json: dict | None = None) -> str: + """Build fake curl stdout with sentinel-separated sections.""" + if curl_json is None: + curl_json = { + "http_code": 200, + "ssl_verify_result": 0, + "http_version": "2", + "url_effective": "https://example.com/", + "time_namelookup": 0.005, + "time_connect": 0.020, + "time_appconnect": 0.080, + "time_starttransfer": 0.150, + "time_total": 0.180, + "certs": ( + "Issuer: R11\nSubject: CN=example.com\nExpire date: Aug 17 23:59:59 2026 GMT\n" + ), + } + if not headers: + headers = "HTTP/2 200\r\ncontent-type: text/html\r\nx-frame-options: DENY\r\n" + return f"{headers}\n%%CURL_JSON%%\n{json.dumps(curl_json)}" + + def test_parses_headers_and_json(self): + """Should parse response headers and curl JSON blob.""" + from rots.commands.proxy._helpers import parse_curl_output + + result = parse_curl_output(self._make_output()) + assert result.http_code == 200 + assert result.ssl_verify_ok is True + assert result.cert_issuer == "R11" + assert result.cert_subject == "CN=example.com" + assert result.cert_expiry == "Aug 17 23:59:59 2026 GMT" + assert result.response_headers["x-frame-options"] == ["DENY"] + assert result.time_total == pytest.approx(0.180) + + def test_missing_sentinel_raises(self): + """Should raise ProxyError when sentinel is missing.""" + from rots.commands.proxy._helpers import ProxyError, parse_curl_output + + with pytest.raises(ProxyError, match="missing sentinel"): + parse_curl_output("just some text without sentinel") + + def test_malformed_json_raises(self): + """Should raise ProxyError when JSON section is malformed.""" + from rots.commands.proxy._helpers import ProxyError, parse_curl_output + + bad_output = "HTTP/2 200\r\n\n%%CURL_JSON%%\n{not valid json" + with pytest.raises(ProxyError, match="JSON output malformed"): + parse_curl_output(bad_output) + + +class TestParseCertExpiryDays: + """Test _parse_cert_expiry_days helper.""" + + def test_valid_future_date(self): + """Should return positive days for future cert.""" + from datetime import datetime, timedelta + + from rots.commands.proxy._helpers import _parse_cert_expiry_days + + future = datetime.now(UTC) + timedelta(days=90) + date_str = future.strftime("%b %d %H:%M:%S %Y GMT") + result = _parse_cert_expiry_days(date_str) + assert result is not None + assert 89 <= result <= 90 + + def test_expired_cert(self): + """Should return negative days for expired cert.""" + from rots.commands.proxy._helpers import _parse_cert_expiry_days + + result = _parse_cert_expiry_days("Jan 01 00:00:00 2020 GMT") + assert result is not None + assert result < 0 + + def test_empty_string_returns_none(self): + """Should return None for empty cert expiry.""" + from rots.commands.proxy._helpers import _parse_cert_expiry_days + + assert _parse_cert_expiry_days("") is None + + def test_malformed_date_returns_none(self): + """Should return None for unparseable date.""" + from rots.commands.proxy._helpers import _parse_cert_expiry_days + + assert _parse_cert_expiry_days("not a date") is None + + +class TestEvaluateAssertions: + """Test evaluate_assertions function.""" + + def _make_result(self, **kwargs): + """Build a ProbeResult with sensible defaults.""" + from rots.commands.proxy._helpers import ProbeResult + + defaults = { + "url": "https://example.com/", + "http_code": 200, + "ssl_verify_result": 0, + "ssl_verify_ok": True, + "cert_issuer": "R11", + "cert_subject": "CN=example.com", + "cert_expiry": "Aug 17 23:59:59 2026 GMT", + "http_version": "2", + "time_namelookup": 0.005, + "time_connect": 0.020, + "time_appconnect": 0.080, + "time_starttransfer": 0.150, + "time_total": 0.180, + "response_headers": { + "X-Frame-Options": ["DENY"], + "O-Via": ["B76s2"], + "Strict-Transport-Security": ["max-age=63072000"], + }, + "curl_json": {}, + } + defaults.update(kwargs) + return ProbeResult(**defaults) + + def test_status_pass(self): + """Should pass when status matches.""" + from rots.commands.proxy._helpers import evaluate_assertions + + result = self._make_result(http_code=200) + checks = evaluate_assertions(result, expect_status=200) + assert len(checks) == 1 + assert checks[0]["passed"] is True + + def test_status_fail(self): + """Should fail when status does not match.""" + from rots.commands.proxy._helpers import evaluate_assertions + + result = self._make_result(http_code=404) + checks = evaluate_assertions(result, expect_status=200) + assert len(checks) == 1 + assert checks[0]["passed"] is False + assert checks[0]["actual"] == "404" + + def test_header_pass(self): + """Should pass when header value matches.""" + from rots.commands.proxy._helpers import evaluate_assertions + + result = self._make_result() + checks = evaluate_assertions(result, expect_headers=("O-Via: B76s2",)) + assert len(checks) == 1 + assert checks[0]["passed"] is True + + def test_header_fail(self): + """Should fail when header value does not match.""" + from rots.commands.proxy._helpers import evaluate_assertions + + result = self._make_result() + checks = evaluate_assertions(result, expect_headers=("O-Via: wrong",)) + assert len(checks) == 1 + assert checks[0]["passed"] is False + + def test_header_missing(self): + """Should fail when expected header is missing.""" + from rots.commands.proxy._helpers import evaluate_assertions + + result = self._make_result() + checks = evaluate_assertions(result, expect_headers=("X-Missing: value",)) + assert len(checks) == 1 + assert checks[0]["passed"] is False + assert checks[0]["actual"] == "(missing)" + + def test_header_case_insensitive(self): + """Should match header keys case-insensitively.""" + from rots.commands.proxy._helpers import evaluate_assertions + + result = self._make_result() + checks = evaluate_assertions(result, expect_headers=("x-frame-options: DENY",)) + assert len(checks) == 1 + assert checks[0]["passed"] is True + + def test_empty_assertions(self): + """Should return empty list when no assertions specified.""" + from rots.commands.proxy._helpers import evaluate_assertions + + result = self._make_result() + checks = evaluate_assertions(result) + assert checks == [] + + def test_cert_days_pass(self): + """Should pass when cert has enough days remaining.""" + from datetime import datetime, timedelta + + from rots.commands.proxy._helpers import evaluate_assertions + + future = datetime.now(UTC) + timedelta(days=90) + cert_expiry = future.strftime("%b %d %H:%M:%S %Y GMT") + result = self._make_result(cert_expiry=cert_expiry) + checks = evaluate_assertions(result, expect_cert_days=30) + cert_check = [c for c in checks if c["check"] == "cert-expiry"] + assert len(cert_check) == 1 + assert cert_check[0]["passed"] is True + assert cert_check[0]["expected"] == ">= 30 days" + + def test_cert_days_fail(self): + """Should fail when cert has fewer days than threshold.""" + from datetime import datetime, timedelta + + from rots.commands.proxy._helpers import evaluate_assertions + + future = datetime.now(UTC) + timedelta(days=15) + cert_expiry = future.strftime("%b %d %H:%M:%S %Y GMT") + result = self._make_result(cert_expiry=cert_expiry) + checks = evaluate_assertions(result, expect_cert_days=30) + cert_check = [c for c in checks if c["check"] == "cert-expiry"] + assert len(cert_check) == 1 + assert cert_check[0]["passed"] is False + assert "days" in cert_check[0]["actual"] + + def test_cert_days_empty_expiry(self): + """Should fail gracefully when cert_expiry is empty.""" + from rots.commands.proxy._helpers import evaluate_assertions + + result = self._make_result(cert_expiry="") + checks = evaluate_assertions(result, expect_cert_days=30) + cert_check = [c for c in checks if c["check"] == "cert-expiry"] + assert len(cert_check) == 1 + assert cert_check[0]["passed"] is False + assert cert_check[0]["actual"] == "(no expiry date available)" + + +class TestRunProbe: + """Test run_probe function.""" + + def _make_curl_stdout(self) -> str: + """Build realistic curl output for mock subprocess.""" + headers = "HTTP/2 200\r\ncontent-type: text/html\r\n" + curl_json = { + "http_code": 200, + "ssl_verify_result": 0, + "http_version": "2", + "url_effective": "https://example.com/", + "time_namelookup": 0.005, + "time_connect": 0.020, + "time_appconnect": 0.080, + "time_starttransfer": 0.150, + "time_total": 0.180, + "certs": "", + } + return f"{headers}\n%%CURL_JSON%%\n{json.dumps(curl_json)}" + + def test_success(self, mocker): + """Should return ProbeResult on successful curl execution.""" + from rots.commands.proxy._helpers import run_probe + + mocker.patch( + "subprocess.run", + return_value=mocker.Mock( + returncode=0, + stdout=self._make_curl_stdout(), + stderr="", + ), + ) + + result = run_probe("https://example.com") + assert result.http_code == 200 + assert result.url == "https://example.com/" + + def test_curl_not_found(self, mocker): + """Should raise ProxyError when curl is not installed.""" + from rots.commands.proxy._helpers import ProxyError, run_probe + + mocker.patch("subprocess.run", side_effect=FileNotFoundError("curl not found")) + + with pytest.raises(ProxyError, match="curl not found"): + run_probe("https://example.com") + + def test_curl_failure(self, mocker): + """Should raise ProxyError on non-zero curl exit.""" + from rots.commands.proxy._helpers import ProxyError, run_probe + + mocker.patch( + "subprocess.run", + return_value=mocker.Mock( + returncode=7, + stdout="", + stderr="Failed to connect", + ), + ) + + with pytest.raises(ProxyError, match="curl failed.*exit 7"): + run_probe("https://example.com") + + def test_curl_timeout(self, mocker): + """Should raise ProxyError when curl times out.""" + from rots.commands.proxy._helpers import ProxyError, run_probe + + mocker.patch( + "subprocess.run", + side_effect=subprocess.TimeoutExpired(cmd="curl", timeout=35), + ) + + with pytest.raises(ProxyError, match="curl timed out"): + run_probe("https://example.com")