
Commit 7c0bc6d

fink: enhance ProcessLinksMiddleware with base URL handling and link transformation
- Added a `get_base_url` utility to reconstruct the client's base URL from forwarded headers.
- Updated `ProcessLinksMiddleware` to use the new utility when transforming links in responses.
- Improved the link transformation logic to handle more scenarios, including differing hostnames and ports.
- Refactored the `ProcessLinksMiddleware` tests to cover the new functionality and edge cases.
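
For example (hostnames here are hypothetical): with the proxy reached at `https://api.example.com` and the upstream STAC API at `http://upstream-stac:8080/stac`, a response link of `http://upstream-stac:8080/stac/collections` is now rewritten to `https://api.example.com/collections`, while a link to an unrelated host is left unchanged and a warning is logged.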
1 parent 430ea66 commit 7c0bc6d

4 files changed: 693 additions & 106 deletions

src/stac_auth_proxy/middleware/ProcessLinksMiddleware.py

Lines changed: 24 additions & 4 deletions

@@ -11,6 +11,7 @@
 from starlette.types import ASGIApp, Scope

 from ..utils.middleware import JsonResponseMiddleware
+from ..utils.requests import get_base_url
 from ..utils.stac import get_links

 logger = logging.getLogger(__name__)
@@ -40,6 +41,11 @@ def should_transform_response(self, request: Request, scope: Scope) -> bool:

     def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, Any]:
         """Update links in the response to include root_path."""
+        # Get the client's actual base URL (accounting for load balancers/proxies)
+        req_base_url = get_base_url(request)
+        parsed_req_url = urlparse(req_base_url)
+        parsed_upstream_url = urlparse(self.upstream_url)
+
         for link in get_links(data):
             href = link.get("href")
             if not href:
@@ -48,12 +54,25 @@ def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, An
             try:
                 parsed_link = urlparse(href)

-                # Ignore links that are not for this proxy
-                if parsed_link.netloc != request.headers.get("host"):
+                if parsed_link.netloc not in [
+                    parsed_req_url.netloc,
+                    parsed_upstream_url.netloc,
+                ]:
+                    logger.warning(
+                        "Ignoring link %s because it is not for an endpoint behind this proxy (%s or %s)",
+                        href,
+                        parsed_req_url.netloc,
+                        parsed_upstream_url.netloc,
+                    )
                     continue

-                # Remove the upstream_url path from the link if it exists
-                parsed_upstream_url = urlparse(self.upstream_url)
+                if parsed_link.netloc == parsed_upstream_url.netloc:
+                    # Replace the upstream host with the client's host
+                    parsed_link = parsed_link._replace(
+                        netloc=parsed_req_url.netloc
+                    )._replace(scheme=parsed_req_url.scheme)
+
+                # Rewrite the link path
                 if parsed_upstream_url.path != "/" and parsed_link.path.startswith(
                     parsed_upstream_url.path
                 ):
@@ -68,6 +87,7 @@ def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, An
                 )

                 link["href"] = urlunparse(parsed_link)
+
             except Exception as e:
                 logger.error(
                     "Failed to parse link href %r, (ignoring): %s", href, str(e)

src/stac_auth_proxy/utils/requests.py

Lines changed: 113 additions & 1 deletion

@@ -1,13 +1,18 @@
 """Utility functions for working with HTTP requests."""

 import json
+import logging
 import re
 from dataclasses import dataclass, field
-from typing import Optional, Sequence
+from typing import Dict, Optional, Sequence
 from urllib.parse import urlparse

+from starlette.requests import Request
+
 from ..config import EndpointMethods

+logger = logging.getLogger(__name__)
+

 def extract_variables(url: str) -> dict:
     """
@@ -90,3 +95,110 @@ def build_server_timing_header(
     if current_value:
         return f"{current_value}, {metric}"
     return metric
+
+
+def parse_forwarded_header(forwarded_header: str) -> Dict[str, str]:
+    """
+    Parse the Forwarded header according to RFC 7239.
+
+    Args:
+        forwarded_header: The Forwarded header value
+
+    Returns:
+        Dictionary containing parsed forwarded information (proto, host, for, by, etc.)
+
+    Example:
+        >>> parse_forwarded_header("for=192.0.2.43; by=203.0.113.60; proto=https; host=api.example.com")
+        {'for': '192.0.2.43', 'by': '203.0.113.60', 'proto': 'https', 'host': 'api.example.com'}
+
+    """
+    # Forwarded header format: "for=192.0.2.43, for=198.51.100.17; by=203.0.113.60; proto=https; host=example.com"
+    # The format is: for=value1, for=value2; by=value; proto=value; host=value
+    # We need to parse all the key=value pairs, taking the first 'for' value
+    forwarded_info = {}
+
+    try:
+        # Parse all key=value pairs separated by semicolons
+        for pair in forwarded_header.split(";"):
+            pair = pair.strip()
+            if "=" in pair:
+                key, value = pair.split("=", 1)
+                key = key.strip()
+                value = value.strip().strip('"')
+
+                # For 'for' field, only take the first value if there are multiple
+                if key == "for" and key not in forwarded_info:
+                    # Extract the first for value (before comma if present)
+                    first_for_value = value.split(",")[0].strip()
+                    forwarded_info[key] = first_for_value
+                elif key != "for":
+                    # For other fields, just use the value as-is
+                    forwarded_info[key] = value
+    except Exception as e:
+        logger.warning(f"Failed to parse Forwarded header '{forwarded_header}': {e}")
+        return {}
+
+    return forwarded_info
+
+
+def get_base_url(request: Request) -> str:
+    """
+    Get the request's base URL, accounting for forwarded headers from load balancers/proxies.
+
+    This function handles both the standard Forwarded header (RFC 7239) and legacy
+    X-Forwarded-* headers to reconstruct the original client URL when the service
+    is deployed behind load balancers or reverse proxies.
+
+    Args:
+        request: The Starlette request object
+
+    Returns:
+        The reconstructed client base URL
+
+    Example:
+        >>> # With Forwarded header
+        >>> request.headers = {"Forwarded": "for=192.0.2.43; proto=https; host=api.example.com"}
+        >>> get_base_url(request)
+        "https://api.example.com/"
+
+        >>> # With X-Forwarded-* headers
+        >>> request.headers = {"X-Forwarded-Host": "api.example.com", "X-Forwarded-Proto": "https"}
+        >>> get_base_url(request)
+        "https://api.example.com/"
+
+    """
+    # Check for standard Forwarded header first (RFC 7239)
+    forwarded_header = request.headers.get("Forwarded")
+    if forwarded_header:
+        try:
+            forwarded_info = parse_forwarded_header(forwarded_header)
+            # Only use Forwarded header if we successfully parsed it and got useful info
+            if forwarded_info and (
+                "proto" in forwarded_info or "host" in forwarded_info
+            ):
+                scheme = forwarded_info.get("proto", request.url.scheme)
+                host = forwarded_info.get("host", request.url.netloc)
+                # Note: Forwarded header doesn't include path, so we use request.base_url.path
+                path = request.base_url.path
+                return f"{scheme}://{host}{path}"
+        except Exception as e:
+            logger.warning(f"Failed to parse Forwarded header: {e}")
+
+    # Fall back to legacy X-Forwarded-* headers
+    forwarded_host = request.headers.get("X-Forwarded-Host")
+    forwarded_proto = request.headers.get("X-Forwarded-Proto")
+    forwarded_path = request.headers.get("X-Forwarded-Path")
+
+    if forwarded_host:
+        # Use forwarded headers to reconstruct the original client URL
+        scheme = forwarded_proto or request.url.scheme
+        netloc = forwarded_host
+        # Use forwarded path if available, otherwise use request base URL path
+        path = forwarded_path or request.base_url.path
+    else:
+        # Fall back to the request's base URL if no forwarded headers
+        scheme = request.url.scheme
+        netloc = request.url.netloc
+        path = request.base_url.path
+
+    return f"{scheme}://{netloc}{path}"
