diff --git a/src/stac_auth_proxy/__main__.py b/src/stac_auth_proxy/__main__.py
index 42d97b5..db33879 100644
--- a/src/stac_auth_proxy/__main__.py
+++ b/src/stac_auth_proxy/__main__.py
@@ -3,16 +3,20 @@
 import uvicorn
 from uvicorn.config import LOGGING_CONFIG
 
-LOGGING_CONFIG["loggers"][__package__] = {
-    "level": "DEBUG",
-    "handlers": ["default"],
-}
-
 uvicorn.run(
     f"{__package__}.app:create_app",
     host="0.0.0.0",
     port=8000,
-    log_config=LOGGING_CONFIG,
+    log_config={
+        **LOGGING_CONFIG,
+        "loggers": {
+            **LOGGING_CONFIG["loggers"],
+            __package__: {
+                "level": "DEBUG",
+                "handlers": ["default"],
+            },
+        },
+    },
     reload=True,
     factory=True,
 )
diff --git a/src/stac_auth_proxy/middleware/ProcessLinksMiddleware.py b/src/stac_auth_proxy/middleware/ProcessLinksMiddleware.py
index ac385bd..54fccd4 100644
--- a/src/stac_auth_proxy/middleware/ProcessLinksMiddleware.py
+++ b/src/stac_auth_proxy/middleware/ProcessLinksMiddleware.py
@@ -4,13 +4,14 @@
 import re
 from dataclasses import dataclass
 from typing import Any, Optional
-from urllib.parse import urlparse, urlunparse
+from urllib.parse import ParseResult, urlparse, urlunparse
 
 from starlette.datastructures import Headers
 from starlette.requests import Request
 from starlette.types import ASGIApp, Scope
 
 from ..utils.middleware import JsonResponseMiddleware
+from ..utils.requests import get_base_url
 from ..utils.stac import get_links
 
 logger = logging.getLogger(__name__)
@@ -40,37 +41,81 @@ def should_transform_response(self, request: Request, scope: Scope) -> bool:
 
     def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, Any]:
         """Update links in the response to include root_path."""
-        for link in get_links(data):
-            href = link.get("href")
-            if not href:
-                continue
+        # Get the client's actual base URL (accounting for load balancers/proxies)
+        req_base_url = get_base_url(request)
+        parsed_req_url = urlparse(req_base_url)
+        parsed_upstream_url = urlparse(self.upstream_url)
 
+        for link in get_links(data):
             try:
-                parsed_link = urlparse(href)
-
-                # Ignore links that are not for this proxy
-                if parsed_link.netloc != request.headers.get("host"):
-                    continue
-
-                # Remove the upstream_url path from the link if it exists
-                parsed_upstream_url = urlparse(self.upstream_url)
-                if parsed_upstream_url.path != "/" and parsed_link.path.startswith(
-                    parsed_upstream_url.path
-                ):
-                    parsed_link = parsed_link._replace(
-                        path=parsed_link.path[len(parsed_upstream_url.path) :]
-                    )
-
-                # Add the root_path to the link if it exists
-                if self.root_path:
-                    parsed_link = parsed_link._replace(
-                        path=f"{self.root_path}{parsed_link.path}"
-                    )
-
-                link["href"] = urlunparse(parsed_link)
+                self._update_link(link, parsed_req_url, parsed_upstream_url)
             except Exception as e:
                 logger.error(
-                    "Failed to parse link href %r, (ignoring): %s", href, str(e)
+                    "Failed to parse link href %r, (ignoring): %s",
+                    link.get("href"),
+                    str(e),
                 )
+
         return data
+
+    def _update_link(
+        self, link: dict[str, Any], request_url: ParseResult, upstream_url: ParseResult
+    ) -> None:
+        """
+        Ensure that link hrefs that are local to the upstream URL are rewritten as
+        local to the proxy.
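+
+        For example, with upstream_url "http://upstream.example.com/api" and
+        root_path "/proxy", an upstream href of
+        "http://upstream.example.com/api/collections" is rewritten to
+        "http://proxy.example.com/proxy/collections" when the client reached the
+        proxy at "proxy.example.com".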
+ """ + if "href" not in link: + logger.warning("Link %r has no href", link) + return + + parsed_link = urlparse(link["href"]) + + if parsed_link.netloc not in [ + request_url.netloc, + upstream_url.netloc, + ]: + logger.debug( + "Ignoring link %s because it is not for an endpoint behind this proxy (%s or %s)", + link["href"], + request_url.netloc, + upstream_url.netloc, + ) + return + + # If the link path is not a descendant of the upstream path, don't transform it + if upstream_url.path != "/" and not parsed_link.path.startswith( + upstream_url.path + ): + logger.debug( + "Ignoring link %s because it is not descendant of upstream path (%s)", + link["href"], + upstream_url.path, + ) + return + + # Replace the upstream host with the client's host + if parsed_link.netloc == upstream_url.netloc: + parsed_link = parsed_link._replace(netloc=request_url.netloc)._replace( + scheme=request_url.scheme + ) + + # Rewrite the link path + if upstream_url.path != "/" and parsed_link.path.startswith(upstream_url.path): + parsed_link = parsed_link._replace( + path=parsed_link.path[len(upstream_url.path) :] + ) + + # Add the root_path to the link if it exists + if self.root_path: + parsed_link = parsed_link._replace( + path=f"{self.root_path}{parsed_link.path}" + ) + + logger.debug( + "Rewriting %r link %r to %r", + link.get("rel"), + link["href"], + urlunparse(parsed_link), + ) + + link["href"] = urlunparse(parsed_link) diff --git a/src/stac_auth_proxy/utils/requests.py b/src/stac_auth_proxy/utils/requests.py index 5dea88e..7b8a53a 100644 --- a/src/stac_auth_proxy/utils/requests.py +++ b/src/stac_auth_proxy/utils/requests.py @@ -1,13 +1,18 @@ """Utility functions for working with HTTP requests.""" import json +import logging import re from dataclasses import dataclass, field -from typing import Optional, Sequence +from typing import Dict, Optional, Sequence from urllib.parse import urlparse +from starlette.requests import Request + from ..config import EndpointMethods +logger = logging.getLogger(__name__) + def extract_variables(url: str) -> dict: """ @@ -90,3 +95,110 @@ def build_server_timing_header( if current_value: return f"{current_value}, {metric}" return metric + + +def parse_forwarded_header(forwarded_header: str) -> Dict[str, str]: + """ + Parse the Forwarded header according to RFC 7239. + + Args: + forwarded_header: The Forwarded header value + + Returns: + Dictionary containing parsed forwarded information (proto, host, for, by, etc.) 
+ + Example: + >>> parse_forwarded_header("for=192.0.2.43; by=203.0.113.60; proto=https; host=api.example.com") + {'for': '192.0.2.43', 'by': '203.0.113.60', 'proto': 'https', 'host': 'api.example.com'} + + """ + # Forwarded header format: "for=192.0.2.43, for=198.51.100.17; by=203.0.113.60; proto=https; host=example.com" + # The format is: for=value1, for=value2; by=value; proto=value; host=value + # We need to parse all the key=value pairs, taking the first 'for' value + forwarded_info = {} + + try: + # Parse all key=value pairs separated by semicolons + for pair in forwarded_header.split(";"): + pair = pair.strip() + if "=" in pair: + key, value = pair.split("=", 1) + key = key.strip() + value = value.strip().strip('"') + + # For 'for' field, only take the first value if there are multiple + if key == "for" and key not in forwarded_info: + # Extract the first for value (before comma if present) + first_for_value = value.split(",")[0].strip() + forwarded_info[key] = first_for_value + elif key != "for": + # For other fields, just use the value as-is + forwarded_info[key] = value + except Exception as e: + logger.warning(f"Failed to parse Forwarded header '{forwarded_header}': {e}") + return {} + + return forwarded_info + + +def get_base_url(request: Request) -> str: + """ + Get the request's base URL, accounting for forwarded headers from load balancers/proxies. + + This function handles both the standard Forwarded header (RFC 7239) and legacy + X-Forwarded-* headers to reconstruct the original client URL when the service + is deployed behind load balancers or reverse proxies. + + Args: + request: The Starlette request object + + Returns: + The reconstructed client base URL + + Example: + >>> # With Forwarded header + >>> request.headers = {"Forwarded": "for=192.0.2.43; proto=https; host=api.example.com"} + >>> get_base_url(request) + "https://api.example.com/" + + >>> # With X-Forwarded-* headers + >>> request.headers = {"X-Forwarded-Host": "api.example.com", "X-Forwarded-Proto": "https"} + >>> get_base_url(request) + "https://api.example.com/" + + """ + # Check for standard Forwarded header first (RFC 7239) + forwarded_header = request.headers.get("Forwarded") + if forwarded_header: + try: + forwarded_info = parse_forwarded_header(forwarded_header) + # Only use Forwarded header if we successfully parsed it and got useful info + if forwarded_info and ( + "proto" in forwarded_info or "host" in forwarded_info + ): + scheme = forwarded_info.get("proto", request.url.scheme) + host = forwarded_info.get("host", request.url.netloc) + # Note: Forwarded header doesn't include path, so we use request.base_url.path + path = request.base_url.path + return f"{scheme}://{host}{path}" + except Exception as e: + logger.warning(f"Failed to parse Forwarded header: {e}") + + # Fall back to legacy X-Forwarded-* headers + forwarded_host = request.headers.get("X-Forwarded-Host") + forwarded_proto = request.headers.get("X-Forwarded-Proto") + forwarded_path = request.headers.get("X-Forwarded-Path") + + if forwarded_host: + # Use forwarded headers to reconstruct the original client URL + scheme = forwarded_proto or request.url.scheme + netloc = forwarded_host + # Use forwarded path if available, otherwise use request base URL path + path = forwarded_path or request.base_url.path + else: + # Fall back to the request's base URL if no forwarded headers + scheme = request.url.scheme + netloc = request.url.netloc + path = request.base_url.path + + return f"{scheme}://{netloc}{path}" diff --git 
a/tests/test_process_links.py b/tests/test_process_links.py index 43752b4..02dce87 100644 --- a/tests/test_process_links.py +++ b/tests/test_process_links.py @@ -1,4 +1,4 @@ -"""Tests for ProcessLinksMiddleware.""" +"""Tests for ProcessLinksMiddleware - Refactored with parametrization.""" import pytest from starlette.requests import Request @@ -6,20 +6,87 @@ from stac_auth_proxy.middleware.ProcessLinksMiddleware import ProcessLinksMiddleware -@pytest.fixture -def middleware(): - """Create a test instance of the middleware.""" - return ProcessLinksMiddleware( - app=None, # We don't need the actual app for these tests +@pytest.mark.parametrize( + "content_type,should_transform", + [ + ("application/json", True), + ("application/geo+json", True), + ("text/html", False), + ("text/plain", False), + ("application/xml", False), + ], +) +def test_should_transform_response_content_types(content_type, should_transform): + """Test that only JSON responses are transformed.""" + middleware = ProcessLinksMiddleware( + app=None, upstream_url="http://upstream.example.com/api", root_path="/proxy", ) + request_scope = { + "type": "http", + "path": "/test", + "headers": [ + (b"host", b"proxy.example.com"), + (b"content-type", content_type.encode()), + ], + } + assert ( + middleware.should_transform_response(Request(request_scope), request_scope) + == should_transform + ) -@pytest.fixture -def request_scope(): - """Create a test request scope.""" - return { +@pytest.mark.parametrize( + "upstream_url,root_path,input_links,expected_links", + [ + # Basic proxy links with upstream path + ( + "http://upstream.example.com/api", + "/proxy", + [ + {"rel": "self", "href": "http://proxy.example.com/api/collections"}, + {"rel": "root", "href": "http://proxy.example.com/api"}, + ], + [ + "http://proxy.example.com/proxy/collections", + "http://proxy.example.com/proxy", + ], + ), + # Proxy links without upstream path + ( + "http://upstream.example.com", + "/proxy", + [ + {"rel": "self", "href": "http://proxy.example.com/collections"}, + {"rel": "root", "href": "http://proxy.example.com/"}, + ], + [ + "http://proxy.example.com/proxy/collections", + "http://proxy.example.com/proxy/", + ], + ), + # Proxy links without root path + ( + "http://upstream.example.com/api", + None, + [ + {"rel": "self", "href": "http://proxy.example.com/api/collections"}, + {"rel": "root", "href": "http://proxy.example.com/api"}, + ], + [ + "http://proxy.example.com/collections", + "http://proxy.example.com", + ], + ), + ], +) +def test_transform_proxy_links(upstream_url, root_path, input_links, expected_links): + """Test transforming proxy links with different configurations.""" + middleware = ProcessLinksMiddleware( + app=None, upstream_url=upstream_url, root_path=root_path + ) + request_scope = { "type": "http", "path": "/test", "headers": [ @@ -28,177 +95,505 @@ def request_scope(): ], } + data = {"links": input_links} + transformed = middleware.transform_json(data, Request(request_scope)) + + for i, expected in enumerate(expected_links): + assert transformed["links"][i]["href"] == expected + + +@pytest.mark.parametrize( + "upstream_url,root_path,input_links,expected_links", + [ + # Upstream links with upstream path + ( + "http://upstream.example.com/api", + "/proxy", + [ + {"rel": "self", "href": "http://upstream.example.com/api/collections"}, + {"rel": "root", "href": "http://upstream.example.com/api"}, + { + "rel": "items", + "href": "http://upstream.example.com/api/collections/test/items", + }, + ], + [ + 
"http://proxy.example.com/proxy/collections", + "http://proxy.example.com/proxy", + "http://proxy.example.com/proxy/collections/test/items", + ], + ), + # Upstream links without upstream path + ( + "http://upstream.example.com", + "/proxy", + [ + {"rel": "self", "href": "http://upstream.example.com/collections"}, + {"rel": "root", "href": "http://upstream.example.com/"}, + ], + [ + "http://proxy.example.com/proxy/collections", + "http://proxy.example.com/proxy/", + ], + ), + # Upstream links without root path + ( + "http://upstream.example.com/api", + None, + [ + {"rel": "self", "href": "http://upstream.example.com/api/collections"}, + {"rel": "root", "href": "http://upstream.example.com/api"}, + ], + [ + "http://proxy.example.com/collections", + "http://proxy.example.com", + ], + ), + ], +) +def test_transform_upstream_links(upstream_url, root_path, input_links, expected_links): + """Test transforming upstream links with different configurations.""" + middleware = ProcessLinksMiddleware( + app=None, upstream_url=upstream_url, root_path=root_path + ) + request_scope = { + "type": "http", + "path": "/test", + "headers": [ + (b"host", b"proxy.example.com"), + (b"content-type", b"application/json"), + ], + } -def test_should_transform_response_json(middleware, request_scope): - """Test that JSON responses are transformed.""" - request = Request(request_scope) - assert middleware.should_transform_response(request, request_scope) - - -def test_should_transform_response_geojson(middleware, request_scope): - """Test that GeoJSON responses are transformed.""" - request_scope["headers"] = [ - (b"host", b"proxy.example.com"), - (b"content-type", b"application/geo+json"), - ] - request = Request(request_scope) - assert middleware.should_transform_response(request, request_scope) + data = {"links": input_links} + transformed = middleware.transform_json(data, Request(request_scope)) + + for i, expected in enumerate(expected_links): + assert transformed["links"][i]["href"] == expected + + +@pytest.mark.parametrize( + "upstream_url,root_path,port,input_links,expected_links", + [ + # Different ports + ( + "http://upstream.example.com:8080/api", + "/proxy", + 3000, + [ + { + "rel": "self", + "href": "http://upstream.example.com:8080/api/collections", + }, + ], + [ + "http://proxy.example.com:3000/proxy/collections", + ], + ), + ], +) +def test_transform_upstream_links_with_ports( + upstream_url, root_path, port, input_links, expected_links +): + """Test transforming upstream links with different ports.""" + middleware = ProcessLinksMiddleware( + app=None, upstream_url=upstream_url, root_path=root_path + ) + request_scope = { + "type": "http", + "path": "/test", + "headers": [ + (b"host", f"proxy.example.com:{port}".encode()), + (b"content-type", b"application/json"), + ], + } + data = {"links": input_links} + transformed = middleware.transform_json(data, Request(request_scope)) -def test_should_transform_response_non_json(middleware, request_scope): - """Test that non-JSON responses are not transformed.""" - request_scope["headers"] = [ - (b"host", b"proxy.example.com"), - (b"content-type", b"text/plain"), - ] - request = Request(request_scope) - assert not middleware.should_transform_response(request, request_scope) + for i, expected in enumerate(expected_links): + assert transformed["links"][i]["href"] == expected -def test_transform_json_with_upstream_path(middleware, request_scope): - """Test transforming links with upstream URL path.""" - request = Request(request_scope) +def 
test_transform_json_different_host(): + """Test that links with different hostnames are not transformed.""" + middleware = ProcessLinksMiddleware( + app=None, + upstream_url="http://upstream.example.com/api", + root_path="/proxy", + ) + request_scope = { + "type": "http", + "path": "/test", + "headers": [ + (b"host", b"proxy.example.com"), + (b"content-type", b"application/json"), + ], + } data = { "links": [ - {"rel": "self", "href": "http://proxy.example.com/api/collections"}, - {"rel": "root", "href": "http://proxy.example.com/api"}, + {"rel": "self", "href": "http://other.example.com/api/collections"}, + {"rel": "root", "href": "http://other.example.com/api"}, ] } - transformed = middleware.transform_json(data, request) + transformed = middleware.transform_json(data, Request(request_scope)) - assert ( - transformed["links"][0]["href"] == "http://proxy.example.com/proxy/collections" - ) - assert transformed["links"][1]["href"] == "http://proxy.example.com/proxy" + assert transformed["links"][0]["href"] == "http://other.example.com/api/collections" + assert transformed["links"][1]["href"] == "http://other.example.com/api" -def test_transform_json_without_upstream_path(middleware, request_scope): - """Test transforming links without upstream URL path.""" +def test_transform_json_invalid_link(): + """Test that invalid links are handled gracefully.""" middleware = ProcessLinksMiddleware( - app=None, upstream_url="http://upstream.example.com", root_path="/proxy" + app=None, + upstream_url="http://upstream.example.com/api", + root_path="/proxy", ) - request = Request(request_scope) + request_scope = { + "type": "http", + "path": "/test", + "headers": [ + (b"host", b"proxy.example.com"), + (b"content-type", b"application/json"), + ], + } data = { "links": [ - {"rel": "self", "href": "http://proxy.example.com/collections"}, - {"rel": "root", "href": "http://proxy.example.com/"}, + {"rel": "self", "href": "not-a-url"}, + {"rel": "root", "href": "http://proxy.example.com/api"}, ] } - transformed = middleware.transform_json(data, request) + transformed = middleware.transform_json(data, Request(request_scope)) - assert ( - transformed["links"][0]["href"] == "http://proxy.example.com/proxy/collections" - ) - assert transformed["links"][1]["href"] == "http://proxy.example.com/proxy/" + assert transformed["links"][0]["href"] == "not-a-url" + assert transformed["links"][1]["href"] == "http://proxy.example.com/proxy" -def test_transform_json_without_root_path(middleware, request_scope): - """Test transforming links without root path.""" +def test_transform_json_nested_links(): + """Test transforming links in nested STAC objects.""" middleware = ProcessLinksMiddleware( - app=None, upstream_url="http://upstream.example.com/api", root_path=None + app=None, + upstream_url="http://upstream.example.com/api", + root_path="/proxy", ) - request = Request(request_scope) + request_scope = { + "type": "http", + "path": "/test", + "headers": [ + (b"host", b"proxy.example.com"), + (b"content-type", b"application/json"), + ], + } data = { "links": [ {"rel": "self", "href": "http://proxy.example.com/api/collections"}, - {"rel": "root", "href": "http://proxy.example.com/api"}, - ] + ], + "collections": [ + { + "id": "test-collection", + "links": [ + { + "rel": "items", + "href": "http://proxy.example.com/api/collections/test-collection/items", + }, + ], + } + ], } - transformed = middleware.transform_json(data, request) + transformed = middleware.transform_json(data, Request(request_scope)) - assert 
transformed["links"][0]["href"] == "http://proxy.example.com/collections" - assert transformed["links"][1]["href"] == "http://proxy.example.com" + # Top-level links should be transformed + assert ( + transformed["links"][0]["href"] == "http://proxy.example.com/proxy/collections" + ) + # Nested links should also be transformed + assert ( + transformed["collections"][0]["links"][0]["href"] + == "http://proxy.example.com/proxy/collections/test-collection/items" + ) -def test_transform_json_different_host(middleware, request_scope): - """Test that links with different hostnames are not transformed.""" - request = Request(request_scope) - data = { - "links": [ - {"rel": "self", "href": "http://other.example.com/api/collections"}, - {"rel": "root", "href": "http://other.example.com/api"}, - ] +def test_transform_without_prefix(): + """Test transforming links without root_path prefix.""" + middleware = ProcessLinksMiddleware( + app=None, + upstream_url="http://upstream.example.com/api", + root_path=None, + ) + request_scope = { + "type": "http", + "path": "/test", + "headers": [ + (b"host", b"proxy.example.com"), + (b"content-type", b"application/json"), + ], } - transformed = middleware.transform_json(data, request) - - assert transformed["links"][0]["href"] == "http://other.example.com/api/collections" - assert transformed["links"][1]["href"] == "http://other.example.com/api" - - -def test_transform_json_invalid_link(middleware, request_scope): - """Test that invalid links are handled gracefully.""" - request = Request(request_scope) - data = { "links": [ - {"rel": "self", "href": "not-a-url"}, - {"rel": "root", "href": "http://proxy.example.com/api"}, + {"rel": "self", "href": "http://proxy.example.com/api/collections"}, + {"rel": "data", "href": "http://proxy.example.com/collections"}, ] } + transformed = middleware.transform_json(data, Request(request_scope)) + assert transformed["links"][0]["href"] == "http://proxy.example.com/collections" + assert transformed["links"][1]["href"] == "http://proxy.example.com/collections" + + +@pytest.mark.parametrize( + "upstream_url,root_path,input_links,expected_links", + [ + # Upstream links with upstream path + ( + "http://upstream.example.com/api", + "/proxy", + [ + {"rel": "self", "href": "http://upstream.example.com/api/collections"}, + {"rel": "root", "href": "http://upstream.example.com/api"}, + { + "rel": "items", + "href": "http://upstream.example.com/api/collections/test/items", + }, + ], + [ + "http://proxy.example.com/proxy/collections", + "http://proxy.example.com/proxy", + "http://proxy.example.com/proxy/collections/test/items", + ], + ), + # Upstream links without upstream path + ( + "http://upstream.example.com", + "/proxy", + [ + {"rel": "self", "href": "http://upstream.example.com/collections"}, + {"rel": "root", "href": "http://upstream.example.com/"}, + {"rel": "root", "href": "http://upstream.example.com/other/path"}, + ], + [ + "http://proxy.example.com/proxy/collections", + "http://proxy.example.com/proxy/", + "http://proxy.example.com/proxy/other/path", + ], + ), + # Upstream links without root path + ( + "http://upstream.example.com/api", + None, + [ + {"rel": "self", "href": "http://upstream.example.com/api/collections"}, + {"rel": "root", "href": "http://upstream.example.com/api"}, + {"rel": "root", "href": "http://upstream.example.com/other/path"}, + ], + [ + "http://proxy.example.com/collections", + "http://proxy.example.com", + # Upstream links without matching root path should be ignored + 
"http://upstream.example.com/other/path", + ], + ), + ], +) +def test_transform_mixed_links(upstream_url, root_path, input_links, expected_links): + """Test transforming a mix of proxy links and upstream links.""" + middleware = ProcessLinksMiddleware( + app=None, + upstream_url=upstream_url, + root_path=root_path, + ) + request_scope = { + "type": "http", + "path": "/test", + "headers": [ + (b"host", b"proxy.example.com"), + (b"content-type", b"application/json"), + ], + } - transformed = middleware.transform_json(data, request) + transformed = middleware.transform_json( + { + "links": input_links, + }, + Request(request_scope), + ) - assert transformed["links"][0]["href"] == "not-a-url" - assert transformed["links"][1]["href"] == "http://proxy.example.com/proxy" + for i, expected in enumerate(expected_links): + assert transformed["links"][i]["href"] == expected -def test_transform_json_nested_links(middleware, request_scope): - """Test transforming links in nested STAC objects.""" - request = Request(request_scope) +def test_transform_upstream_links_nested_objects(): + """Test transforming upstream links in nested STAC objects.""" + middleware = ProcessLinksMiddleware( + app=None, + upstream_url="http://upstream.example.com/api", + root_path="/proxy", + ) + request_scope = { + "type": "http", + "path": "/test", + "headers": [ + (b"host", b"proxy.example.com"), + (b"content-type", b"application/json"), + ], + } data = { "links": [ - {"rel": "self", "href": "http://proxy.example.com/api"}, + {"rel": "self", "href": "http://upstream.example.com/api"}, ], "collections": [ { "id": "test-collection", "links": [ - { - "rel": "self", - "href": "http://proxy.example.com/api/collections/test-collection", - }, { "rel": "items", - "href": "http://proxy.example.com/api/collections/test-collection/items", + "href": "http://upstream.example.com/api/collections/test-collection/items", }, ], } ], } - transformed = middleware.transform_json(data, request) + transformed = middleware.transform_json(data, Request(request_scope)) + # Top-level links should be transformed assert transformed["links"][0]["href"] == "http://proxy.example.com/proxy" + + # Nested links should also be transformed assert ( transformed["collections"][0]["links"][0]["href"] - == "http://proxy.example.com/proxy/collections/test-collection" - ) - assert ( - transformed["collections"][0]["links"][1]["href"] == "http://proxy.example.com/proxy/collections/test-collection/items" ) -def test_transform_without_prefix(request_scope): - """Sometimes the upstream url will have a path, but the links won't.""" +@pytest.mark.parametrize( + "headers,expected_base_url", + [ + # X-Forwarded-* headers + ( + [ + (b"host", b"internal-proxy:8080"), + (b"content-type", b"application/json"), + (b"x-forwarded-host", b"api.example.com"), + (b"x-forwarded-proto", b"https"), + (b"x-forwarded-path", b"/api/v1"), + ], + "https://api.example.com", + ), + # Partial X-Forwarded-* headers + ( + [ + (b"host", b"internal-proxy:8080"), + (b"content-type", b"application/json"), + (b"x-forwarded-host", b"api.example.com"), + ], + "http://api.example.com", # Falls back to request scheme + ), + # No forwarded headers + ( + [ + (b"host", b"proxy.example.com"), + (b"content-type", b"application/json"), + ], + "http://proxy.example.com", + ), + # Standard Forwarded header + ( + [ + (b"host", b"internal-proxy:8080"), + (b"content-type", b"application/json"), + ( + b"forwarded", + b"for=192.0.2.43; by=203.0.113.60; proto=https; host=api.example.com", + ), + ], + 
"https://api.example.com", + ), + # Forwarded header with multiple proxies + ( + [ + (b"host", b"internal-proxy:8080"), + (b"content-type", b"application/json"), + ( + b"forwarded", + b"for=192.0.2.43, for=198.51.100.17; by=203.0.113.60; proto=https; host=api.example.com", + ), + ], + "https://api.example.com", + ), + # Forwarded header with quoted values + ( + [ + (b"host", b"internal-proxy:8080"), + (b"content-type", b"application/json"), + ( + b"forwarded", + b'for="192.0.2.43"; by="203.0.113.60"; proto="https"; host="api.example.com"', + ), + ], + "https://api.example.com", + ), + # Forwarded header with partial info + ( + [ + (b"host", b"internal-proxy:8080"), + (b"content-type", b"application/json"), + (b"forwarded", b"for=192.0.2.43; host=api.example.com"), + ], + "http://api.example.com", # Falls back to request scheme + ), + # Forwarded header priority over X-Forwarded-* + ( + [ + (b"host", b"internal-proxy:8080"), + (b"content-type", b"application/json"), + (b"x-forwarded-host", b"x-forwarded.example.com"), + (b"x-forwarded-proto", b"http"), + (b"forwarded", b"for=192.0.2.43; proto=https; host=api.example.com"), + ], + "https://api.example.com", + ), + # Malformed Forwarded header falls back to X-Forwarded-* + ( + [ + (b"host", b"internal-proxy:8080"), + (b"content-type", b"application/json"), + (b"x-forwarded-host", b"api.example.com"), + (b"x-forwarded-proto", b"https"), + (b"forwarded", b"malformed header content"), + ], + "https://api.example.com", + ), + ], +) +def test_transform_with_forwarded_headers(headers, expected_base_url): + """Test transforming links with various forwarded header scenarios.""" middleware = ProcessLinksMiddleware( - app=None, upstream_url="http://upstream.example.com/prod/", root_path="" + app=None, upstream_url="http://upstream.example.com/api", root_path="/proxy" ) - request = Request(request_scope) + request_scope = { + "type": "http", + "path": "/test", + "headers": headers, + } data = { "links": [ - {"rel": "data", "href": "http://proxy.example.com/collections"}, + {"rel": "self", "href": "http://upstream.example.com/api/collections"}, + {"rel": "root", "href": "http://upstream.example.com/api"}, ] } - transformed = middleware.transform_json(data, request) - assert transformed["links"][0]["href"] == "http://proxy.example.com/collections" + + transformed = middleware.transform_json(data, Request(request_scope)) + + # Should use the forwarded headers to construct the correct client URL + # but not include the forwarded path in the response URLs + assert transformed["links"][0]["href"] == f"{expected_base_url}/proxy/collections" + assert transformed["links"][1]["href"] == f"{expected_base_url}/proxy" diff --git a/tests/test_utils.py b/tests/test_utils.py index 1f9d354..1a94891 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -3,7 +3,11 @@ import pytest from utils import parse_query_string -from stac_auth_proxy.utils.requests import extract_variables +from stac_auth_proxy.utils.requests import ( + extract_variables, + get_base_url, + parse_forwarded_header, +) @pytest.mark.parametrize( @@ -35,3 +39,93 @@ def test_extract_variables(url, expected): def test_parse_query_string(query, expected): """Validate test helper for parsing query strings.""" assert parse_query_string(query) == expected + + +@pytest.mark.parametrize( + "header, expected", + ( + # Basic Forwarded header parsing + ( + "for=192.0.2.43; by=203.0.113.60; proto=https; host=api.example.com", + { + "for": "192.0.2.43", + "by": "203.0.113.60", + "proto": "https", + "host": 
"api.example.com", + }, + ), + # Multiple for values - should only take the first + ( + "for=192.0.2.43, for=198.51.100.17; by=203.0.113.60; proto=https; host=api.example.com", + { + "for": "192.0.2.43", + "by": "203.0.113.60", + "proto": "https", + "host": "api.example.com", + }, + ), + # Quoted values + ( + 'for="192.0.2.43"; by="203.0.113.60"; proto="https"; host="api.example.com"', + { + "for": "192.0.2.43", + "by": "203.0.113.60", + "proto": "https", + "host": "api.example.com", + }, + ), + # Malformed content + ("malformed header content", {}), + # Empty content + ("", {}), + ), +) +def test_parse_forwarded_header(header, expected): + """Test Forwarded header parsing with various scenarios.""" + result = parse_forwarded_header(header) + assert result == expected + + +@pytest.mark.parametrize( + "headers, expected_url", + ( + # Forwarded header + ( + [ + (b"host", b"internal-proxy:8080"), + (b"forwarded", b"for=192.0.2.43; proto=https; host=api.example.com"), + ], + "https://api.example.com/", + ), + # X-Forwarded-* headers + ( + [ + (b"host", b"internal-proxy:8080"), + (b"x-forwarded-host", b"api.example.com"), + (b"x-forwarded-proto", b"https"), + ], + "https://api.example.com/", + ), + # No forwarded headers + ( + [ + (b"host", b"proxy.example.com"), + ], + "http://proxy.example.com/", + ), + ), +) +def test_get_base_url(headers, expected_url): + """Test get_base_url with various header configurations.""" + from starlette.requests import Request + + scope = { + "type": "http", + "method": "GET", + "path": "/test", + "headers": headers, + } + request = Request(scope) + + result = get_base_url(request) + assert result == expected_url