Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/rots/commands/proxy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@

"""Proxy management commands for OTS containers."""

from .app import app, diff, reload, render, trace
from .app import app, diff, probe, reload, render, trace

__all__ = [
"app",
"diff",
"probe",
"reload",
"render",
"trace",
Expand Down
293 changes: 293 additions & 0 deletions src/rots/commands/proxy/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@

import contextlib
import copy
import dataclasses
import json
import socket
import subprocess
import tempfile
import threading
import time
import urllib.parse
from datetime import UTC
from http.server import BaseHTTPRequestHandler, HTTPServer
from pathlib import Path
from typing import TYPE_CHECKING
Expand All @@ -37,6 +39,27 @@ class ProxyError(Exception):
"""Error during proxy configuration."""


@dataclasses.dataclass
class ProbeResult:
    """Result of probing a URL with curl.

    Built by ``parse_curl_output`` from curl's ``-D -`` header dump plus
    the ``--write-out '%{json}'`` blob; the ``time_*`` fields mirror
    curl's write-out timing variables (all in seconds).
    """

    url: str  # effective URL (curl url_effective; falls back to the request URL)
    http_code: int  # final HTTP status code; 0 when curl reported none
    ssl_verify_result: int  # 0 = valid chain
    ssl_verify_ok: bool  # convenience flag: ssl_verify_result == 0
    cert_issuer: str  # leaf cert issuer, "" when not present in curl output
    cert_subject: str  # leaf cert subject, "" when not present in curl output
    cert_expiry: str  # e.g. "Aug 17 23:59:59 2026 GMT"; "" when unavailable
    http_version: str  # negotiated HTTP version as reported by curl
    time_namelookup: float  # seconds until DNS resolution completed
    time_connect: float  # seconds until TCP connect completed
    time_appconnect: float  # TLS handshake complete
    time_starttransfer: float  # TTFB
    time_total: float  # total transfer time
    response_headers: dict[str, str]  # parsed response headers (later duplicates win)
    curl_json: dict  # raw write-out for --json passthrough


def parse_trace_url(url: str) -> urllib.parse.ParseResult:
"""Normalise and validate a URL for ``proxy trace``.

Expand Down Expand Up @@ -493,3 +516,273 @@ def _read_stderr() -> str:
except subprocess.TimeoutExpired:
proc.kill()
proc.wait(timeout=3)


# ---------------------------------------------------------------------------
# Probe helpers
# ---------------------------------------------------------------------------

_CURL_SENTINEL = "%%CURL_JSON%%"


def build_curl_args(
url: str,
*,
resolve: str | None = None,
connect_to: str | None = None,
cacert: Path | None = None,
cert_status: bool = False,
extra_headers: tuple[str, ...] = (),
timeout: int = 30,
method: str | None = None,
insecure: bool = False,
follow_redirects: bool = False,
) -> list[str]:
"""Build the curl command list for probing *url*.

Returns the full argv list without executing anything — purely
testable by asserting on the returned list.
"""
cmd = [
"curl",
"-s",
"-o",
"/dev/null",
"-D",
"-",
"-w",
f"\n{_CURL_SENTINEL}\n%{{json}}",
"--max-time",
str(timeout),
Comment on lines +546 to +556
Copy link

Copilot AI Mar 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

build_curl_args uses -s (silent), which suppresses curl's error messages. Since run_probe relies on proc.stderr/result.stderr to explain failures, many real-world failures will surface as curl failed ...: with an empty message. Consider switching to -sS (or -s + -S) so errors remain available while keeping progress output quiet.

Copilot uses AI. Check for mistakes.
]

if resolve is not None:
cmd.extend(["--resolve", resolve])
if connect_to is not None:
cmd.extend(["--connect-to", connect_to])
if cacert is not None:
cmd.extend(["--cacert", str(cacert)])
if cert_status:
cmd.append("--cert-status")
for h in extra_headers:
cmd.extend(["-H", h])
if method is not None:
cmd.extend(["-X", method])
if insecure:
cmd.append("-k")
if follow_redirects:
cmd.append("-L")

cmd.append(url)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

security-high high

The build_curl_args function is vulnerable to argument injection. The user-supplied url is appended directly to the cmd list. If the url starts with a hyphen (e.g., -K/etc/passwd), curl will interpret it as an option. This allows an attacker who can control the url parameter to inject arbitrary curl options, which could lead to sensitive file disclosure (using -K or --config) or arbitrary file write (using -o or --output).

To remediate this, use the -- separator to signal the end of options before appending the URL, or validate the URL to ensure it has a valid scheme and hostname.

Suggested change
cmd.append(url)
cmd.append("--")
cmd.append(url)

return cmd


def parse_curl_output(stdout: str) -> ProbeResult:
    """Parse combined curl output (``-D -`` headers + ``-w '%{json}'``).

    Everything before the ``%%CURL_JSON%%`` sentinel is the HTTP
    response header dump; everything after is the JSON blob emitted by
    curl's ``--write-out '%{json}'``.

    Raises:
        ProxyError: When the sentinel is missing or JSON is malformed.
    """
    if _CURL_SENTINEL not in stdout:
        raise ProxyError("curl output missing sentinel — unexpected format")

    header_part, _, json_part = stdout.partition(_CURL_SENTINEL)

    # Collect "Name: value" pairs, ignoring status lines entirely.
    response_headers: dict[str, str] = {}
    for raw_line in header_part.strip().splitlines():
        if raw_line.startswith("HTTP/") or ":" not in raw_line:
            continue
        name, _, value = raw_line.partition(":")
        response_headers[name.strip()] = value.strip()

    try:
        curl_json = json.loads(json_part.strip())
    except (json.JSONDecodeError, ValueError) as e:
        raise ProxyError(f"curl JSON output malformed: {e}") from e

    # Pull the first Issuer / Subject / Expire-date lines out of the
    # free-form certs text curl emits.
    cert_issuer = ""
    cert_subject = ""
    cert_expiry = ""
    for entry in curl_json.get("certs", "").splitlines():
        entry = entry.strip()
        if entry.startswith("Issuer:") and not cert_issuer:
            cert_issuer = entry[len("Issuer:") :].strip()
        elif entry.startswith("Subject:") and not cert_subject:
            cert_subject = entry[len("Subject:") :].strip()
        elif entry.startswith("Expire date:") and not cert_expiry:
            cert_expiry = entry[len("Expire date:") :].strip()

    verify_code = curl_json.get("ssl_verify_result", -1)
    return ProbeResult(
        url=curl_json.get("url_effective", curl_json.get("url", "")),
        http_code=curl_json.get("http_code", 0),
        ssl_verify_result=verify_code,
        ssl_verify_ok=verify_code == 0,
        cert_issuer=cert_issuer,
        cert_subject=cert_subject,
        cert_expiry=cert_expiry,
        http_version=curl_json.get("http_version", ""),
        time_namelookup=curl_json.get("time_namelookup", 0.0),
        time_connect=curl_json.get("time_connect", 0.0),
        time_appconnect=curl_json.get("time_appconnect", 0.0),
        time_starttransfer=curl_json.get("time_starttransfer", 0.0),
        time_total=curl_json.get("time_total", 0.0),
        response_headers=response_headers,
        curl_json=curl_json,
    )


def _parse_cert_expiry_days(cert_expiry: str) -> int | None:
"""Parse cert expiry string and return days remaining.

The format comes from curl's ``%{json}`` output, e.g.,
``"Aug 17 23:59:59 2026 GMT"``.

Returns None on empty string or parse failure.
"""
if not cert_expiry:
return None
try:
from datetime import datetime

expiry = datetime.strptime(cert_expiry, "%b %d %H:%M:%S %Y %Z")
Copy link

Copilot AI Mar 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_parse_cert_expiry_days uses datetime.strptime(..., "%b %d %H:%M:%S %Y %Z") to parse the timezone. %Z parsing is platform/locale-dependent and may not reliably accept GMT, causing valid expiries to be treated as unparseable and cert assertions to fail. Since you replace tzinfo with UTC anyway, consider parsing with a literal " GMT" suffix (or using a more robust parser) and then attaching UTC.

Suggested change
expiry = datetime.strptime(cert_expiry, "%b %d %H:%M:%S %Y %Z")
# Normalize and strip the fixed " GMT" suffix instead of relying on
# platform-dependent %Z parsing. We then explicitly attach UTC.
normalized = cert_expiry.strip()
if normalized.endswith(" GMT"):
normalized = normalized[: -len(" GMT")]
expiry = datetime.strptime(normalized, "%b %d %H:%M:%S %Y")

Copilot uses AI. Check for mistakes.
expiry = expiry.replace(tzinfo=UTC)
now = datetime.now(UTC)
return (expiry - now).days
except (ValueError, OverflowError):
return None


def evaluate_assertions(
    result: ProbeResult,
    *,
    expect_status: int | None = None,
    expect_headers: tuple[str, ...] = (),
    expect_cert_days: int | None = None,
) -> list[dict]:
    """Evaluate assertions against a probe result.

    Returns a list of ``{"check": str, "passed": bool, "expected": str,
    "actual": str}`` dicts. Returns an empty list when no assertions
    are specified.
    """
    outcomes: list[dict] = []

    if expect_status is not None:
        outcomes.append(
            {
                "check": "status",
                "passed": result.http_code == expect_status,
                "expected": str(expect_status),
                "actual": str(result.http_code),
            }
        )

    # Case-insensitive header lookup, preserving the original key for
    # the "actual" report line.
    by_lower = {
        name.lower(): (name, value)
        for name, value in result.response_headers.items()
    }

    for spec in expect_headers:
        name, _, wanted = spec.partition(":")
        name = name.strip()
        wanted = wanted.strip()

        found_name, found_value = by_lower.get(name.lower(), (name, ""))
        shown = f"{found_name}: {found_value}" if found_value else "(missing)"
        outcomes.append(
            {
                "check": f"header {name}",
                "passed": found_value == wanted,
                "expected": f"{name}: {wanted}",
                "actual": shown,
            }
        )

    if expect_cert_days is not None:
        remaining = _parse_cert_expiry_days(result.cert_expiry)
        if remaining is None:
            # Unparseable / absent expiry counts as a failed check.
            shown_days = "(no expiry date available)"
        else:
            shown_days = f"{remaining} days"
        outcomes.append(
            {
                "check": "cert-expiry",
                "passed": remaining is not None and remaining >= expect_cert_days,
                "expected": f">= {expect_cert_days} days",
                "actual": shown_days,
            }
        )

    return outcomes


def run_probe(
    url: str,
    *,
    resolve: str | None = None,
    connect_to: str | None = None,
    cacert: Path | None = None,
    cert_status: bool = False,
    extra_headers: tuple[str, ...] = (),
    timeout: int = 30,
    method: str | None = None,
    insecure: bool = False,
    follow_redirects: bool = False,
    executor: Executor | None = None,
) -> ProbeResult:
    """Execute curl and return parsed probe results.

    Uses *executor* when provided (remote execution via SSH), otherwise
    runs curl locally via subprocess.

    Raises:
        ProxyError: On curl errors (not found, timeout, non-zero exit).
    """
    argv = build_curl_args(
        url,
        resolve=resolve,
        connect_to=connect_to,
        cacert=cacert,
        cert_status=cert_status,
        extra_headers=extra_headers,
        timeout=timeout,
        method=method,
        insecure=insecure,
        follow_redirects=follow_redirects,
    )

    # Allow the outer wait slightly longer than curl's own --max-time so
    # the two timeouts do not race each other.
    outer_timeout = timeout + 5

    if _is_remote(executor):
        remote = executor.run(argv, timeout=outer_timeout)  # type: ignore[union-attr]
        if not remote.ok:
            raise ProxyError(f"curl failed (exit {remote.returncode}): {remote.stderr}")
        return parse_curl_output(remote.stdout)

    try:
        completed = subprocess.run(
            argv,
            capture_output=True,
            text=True,
            timeout=outer_timeout,
        )
    except FileNotFoundError as exc:
        raise ProxyError("curl not found in PATH") from exc
    except subprocess.TimeoutExpired as exc:
        raise ProxyError("curl timed out") from exc

    if completed.returncode != 0:
        raise ProxyError(f"curl failed (exit {completed.returncode}): {completed.stderr}")

    return parse_curl_output(completed.stdout)
Loading