Skip to content

Commit 44cff51

Browse files
committed
Merge branch 'main' into example/opa
2 parents 1916e9c + 046c731 commit 44cff51

File tree

14 files changed

+519
-82
lines changed

14 files changed

+519
-82
lines changed

README.md

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,13 @@
1010
1111
STAC Auth Proxy is a proxy API that mediates between the client and your internally accessible STAC API to provide flexible authentication, authorization, and content-filtering mechanisms.
1212

13-
## Features
13+
## Features
1414

15-
- 🔐 Authentication: Selectively apply OIDC auth to some or all endpoints & methods
16-
- 🎟️ Content Filtering: Apply CQL2 filters to client requests, filtering API content based on user context
17-
- 📖 OpenAPI Augmentation: Update [OpenAPI](https://swagger.io/specification/) with security requirements, keeping auto-generated docs/UIs accurate (e.g. [Swagger UI](https://swagger.io/tools/swagger-ui/))
15+
- 🔐 Authentication: Selectively apply [OpenID Connect (OIDC)](https://openid.net/developers/how-connect-works/) auth*n token validation & optional scope requirements to some or all endpoints & methods
16+
- 🛂 Content Filtering: Apply CQL2 filters to client requests, utilizing the [Filter Extension](https://github.com/stac-api-extensions/filter?tab=readme-ov-file) to filter API content based on user context
17+
- 🧩 Authentication Extension: Integrate the [Authentication Extension](https://github.com/stac-extensions/authentication) into API responses
18+
- 📘 OpenAPI Augmentation: Update API's [OpenAPI document](https://swagger.io/specification/) with security requirements, keeping auto-generated docs/UIs accurate (e.g. [Swagger UI](https://swagger.io/tools/swagger-ui/))
19+
- 🗜️ Response compression: Compress API responses via [`starlette-cramjam`](https://github.com/developmentseed/starlette-cramjam/)
1820

1921
## Usage
2022

@@ -75,6 +77,10 @@ The application is configurable via environment variables.
7577
- **Type:** boolean
7678
- **Required:** No, defaults to `true`
7779
- **Example:** `false`, `1`, `True`
80+
- **`ENABLE_COMPRESSION`**, enable response compression
81+
- **Type:** boolean
82+
- **Required:** No, defaults to `true`
83+
- **Example:** `false`, `1`, `True`
7884
- **`HEALTHZ_PREFIX`**, path prefix for health check endpoints
7985
- **Type:** string
8086
- **Required:** No, defaults to `/healthz`
@@ -115,6 +121,10 @@ The application is configurable via environment variables.
115121
"^/healthz": ["GET"]
116122
}
117123
```
124+
- **`ENABLE_AUTHENTICATION_EXTENSION`**, enable authentication extension in STAC API responses
125+
- **Type:** boolean
126+
- **Required:** No, defaults to `true`
127+
- **Example:** `false`, `1`, `True`
118128
- **`OPENAPI_SPEC_ENDPOINT`**, path of OpenAPI specification, used for augmenting spec response with auth configuration
119129
- **Type:** string or null
120130
- **Required:** No, defaults to `null` (disabled)

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@ classifiers = [
66
"License :: OSI Approved :: MIT License",
77
]
88
dependencies = [
9+
"boto3>=1.37.16",
910
"brotli>=1.1.0",
1011
"cql2>=0.3.6",
12+
"cryptography>=44.0.1",
1113
"fastapi>=0.115.5",
1214
"httpx[http2]>=0.28.0",
1315
"jinja2>=3.1.4",

src/stac_auth_proxy/app.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from .middleware import (
1818
AddProcessTimeHeaderMiddleware,
1919
ApplyCql2FilterMiddleware,
20+
AuthenticationExtensionMiddleware,
2021
BuildCql2FilterMiddleware,
2122
EnforceAuthMiddleware,
2223
OpenApiMiddleware,
@@ -86,6 +87,14 @@ async def lifespan(app: FastAPI):
8687
#
8788
# Middleware (order is important, last added = first to run)
8889
#
90+
if settings.enable_authentication_extension:
91+
app.add_middleware(
92+
AuthenticationExtensionMiddleware,
93+
default_public=settings.default_public,
94+
public_endpoints=settings.public_endpoints,
95+
private_endpoints=settings.private_endpoints,
96+
)
97+
8998
if settings.openapi_spec_endpoint:
9099
app.add_middleware(
91100
OpenApiMiddleware,
@@ -105,9 +114,10 @@ async def lifespan(app: FastAPI):
105114
items_filter=settings.items_filter(),
106115
)
107116

108-
app.add_middleware(
109-
CompressionMiddleware,
110-
)
117+
if settings.enable_compression:
118+
app.add_middleware(
119+
CompressionMiddleware,
120+
)
111121

112122
app.add_middleware(
113123
AddProcessTimeHeaderMiddleware,

src/stac_auth_proxy/config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ class Settings(BaseSettings):
4040

4141
wait_for_upstream: bool = True
4242
check_conformance: bool = True
43-
44-
# Endpoints
43+
enable_compression: bool = True
44+
enable_authentication_extension: bool = True
4545
healthz_prefix: str = Field(pattern=_PREFIX_PATTERN, default="/healthz")
4646
openapi_spec_endpoint: Optional[str] = Field(pattern=_PREFIX_PATTERN, default=None)
4747

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
"""Middleware to add auth information to item response served by upstream API."""
2+
3+
import logging
4+
import re
5+
from dataclasses import dataclass, field
6+
from itertools import chain
7+
from typing import Any
8+
from urllib.parse import urlparse
9+
10+
from starlette.datastructures import Headers
11+
from starlette.requests import Request
12+
from starlette.types import ASGIApp
13+
14+
from ..config import EndpointMethods
15+
from ..utils.middleware import JsonResponseMiddleware
16+
from ..utils.requests import find_match
17+
18+
logger = logging.getLogger(__name__)
19+
20+
21+
@dataclass
22+
class AuthenticationExtensionMiddleware(JsonResponseMiddleware):
23+
"""Middleware to add the authentication extension to the response."""
24+
25+
app: ASGIApp
26+
27+
default_public: bool
28+
private_endpoints: EndpointMethods
29+
public_endpoints: EndpointMethods
30+
31+
auth_scheme_name: str = "oauth"
32+
auth_scheme: dict[str, Any] = field(default_factory=dict)
33+
extension_url: str = (
34+
"https://stac-extensions.github.io/authentication/v1.1.0/schema.json"
35+
)
36+
37+
json_content_type_expr: str = r"application/(geo\+)?json"
38+
39+
state_key: str = "oidc_metadata"
40+
41+
def should_transform_response(
42+
self, request: Request, response_headers: Headers
43+
) -> bool:
44+
"""Determine if the response should be transformed."""
45+
# Match STAC catalog, collection, or item URLs with a single regex
46+
return all(
47+
[
48+
re.match(
49+
# catalog, collections, collection, items, item, search
50+
r"^(/|/collections(/[^/]+(/items(/[^/]+)?)?)?|/search)$",
51+
request.url.path,
52+
),
53+
re.match(
54+
self.json_content_type_expr,
55+
response_headers.get("content-type", ""),
56+
),
57+
]
58+
)
59+
60+
def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, Any]:
61+
"""Augment the STAC Item with auth information."""
62+
extensions = data.setdefault("stac_extensions", [])
63+
if self.extension_url not in extensions:
64+
extensions.append(self.extension_url)
65+
66+
# auth:schemes
67+
# ---
68+
# A property that contains all of the scheme definitions used by Assets and
69+
# Links in the STAC Item or Collection.
70+
# - Catalogs
71+
# - Collections
72+
# - Item Properties
73+
74+
oidc_metadata = getattr(request.state, self.state_key, {})
75+
if not oidc_metadata:
76+
logger.error(
77+
"OIDC metadata not found in scope. Skipping authentication extension."
78+
)
79+
return data
80+
81+
scheme_loc = data["properties"] if "properties" in data else data
82+
schemes = scheme_loc.setdefault("auth:schemes", {})
83+
schemes[self.auth_scheme_name] = {
84+
"type": "oauth2",
85+
"description": "requires an authentication bearertoken",
86+
"flows": {
87+
"authorizationCode": {
88+
"authorizationUrl": oidc_metadata["authorization_endpoint"],
89+
"tokenUrl": oidc_metadata.get("token_endpoint"),
90+
"scopes": {
91+
k: k for k in sorted(oidc_metadata.get("scopes_supported", []))
92+
},
93+
},
94+
},
95+
}
96+
97+
# auth:refs
98+
# ---
99+
# Annotate links with "auth:refs": [auth_scheme]
100+
links = chain(
101+
# Item/Collection
102+
data.get("links", []),
103+
# Collections/Items/Search
104+
(
105+
link
106+
for prop in ["features", "collections"]
107+
for object_with_links in data.get(prop, [])
108+
for link in object_with_links.get("links", [])
109+
),
110+
)
111+
for link in links:
112+
if "href" not in link:
113+
logger.warning("Link %s has no href", link)
114+
continue
115+
match = find_match(
116+
path=urlparse(link["href"]).path,
117+
method="GET",
118+
private_endpoints=self.private_endpoints,
119+
public_endpoints=self.public_endpoints,
120+
default_public=self.default_public,
121+
)
122+
if match.is_private:
123+
link.setdefault("auth:refs", []).append(self.auth_scheme_name)
124+
125+
return data

src/stac_auth_proxy/middleware/EnforceAuthMiddleware.py

Lines changed: 58 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Middleware to enforce authentication."""
22

33
import logging
4-
from dataclasses import dataclass
4+
from dataclasses import dataclass, field
55
from typing import Annotated, Any, Optional, Sequence
66
from urllib.parse import urlparse, urlunparse
77

@@ -18,6 +18,53 @@
1818
logger = logging.getLogger(__name__)
1919

2020

21+
@dataclass
22+
class OidcService:
23+
"""OIDC configuration and JWKS client."""
24+
25+
oidc_config_url: HttpUrl
26+
jwks_client: jwt.PyJWKClient = field(init=False)
27+
metadata: dict[str, Any] = field(init=False)
28+
29+
def __post_init__(self) -> None:
30+
"""Initialize OIDC config and JWKS client."""
31+
logger.debug("Requesting OIDC config")
32+
origin_url = str(self.oidc_config_url)
33+
34+
try:
35+
response = httpx.get(origin_url)
36+
response.raise_for_status()
37+
self.metadata = response.json()
38+
assert self.metadata, "OIDC metadata is empty"
39+
40+
# NOTE: We manually replace the origin of the jwks_uri in the event that
41+
# the jwks_uri is not available from within the proxy.
42+
oidc_url = urlparse(origin_url)
43+
jwks_uri = urlunparse(
44+
urlparse(self.metadata["jwks_uri"])._replace(
45+
netloc=oidc_url.netloc, scheme=oidc_url.scheme
46+
)
47+
)
48+
if jwks_uri != self.metadata["jwks_uri"]:
49+
logger.warning(
50+
"JWKS URI has been rewritten from %s to %s",
51+
self.metadata["jwks_uri"],
52+
jwks_uri,
53+
)
54+
self.jwks_client = jwt.PyJWKClient(jwks_uri)
55+
except httpx.HTTPStatusError as e:
56+
logger.error(
57+
"Received a non-200 response when fetching OIDC config: %s",
58+
e.response.text,
59+
)
60+
raise OidcFetchError(
61+
f"Request for OIDC config failed with status {e.response.status_code}"
62+
) from e
63+
except httpx.RequestError as e:
64+
logger.error("Error fetching OIDC config from %s: %s", origin_url, str(e))
65+
raise OidcFetchError(f"Request for OIDC config failed: {str(e)}") from e
66+
67+
2168
@dataclass
2269
class EnforceAuthMiddleware:
2370
"""Middleware to enforce authentication."""
@@ -26,56 +73,11 @@ class EnforceAuthMiddleware:
2673
private_endpoints: EndpointMethods
2774
public_endpoints: EndpointMethods
2875
default_public: bool
29-
3076
oidc_config_url: HttpUrl
3177
allowed_jwt_audiences: Optional[Sequence[str]] = None
32-
3378
state_key: str = "payload"
3479

35-
# Generated attributes
36-
_jwks_client: Optional[jwt.PyJWKClient] = None
37-
38-
@property
39-
def jwks_client(self) -> jwt.PyJWKClient:
40-
"""Get the OIDC configuration URL."""
41-
if not self._jwks_client:
42-
logger.debug("Requesting OIDC config")
43-
origin_url = str(self.oidc_config_url)
44-
45-
try:
46-
response = httpx.get(origin_url)
47-
response.raise_for_status()
48-
oidc_config = response.json()
49-
50-
# NOTE: We manually replace the origin of the jwks_uri in the event that
51-
# the jwks_uri is not available from within the proxy.
52-
oidc_url = urlparse(origin_url)
53-
jwks_uri = urlunparse(
54-
urlparse(oidc_config["jwks_uri"])._replace(
55-
netloc=oidc_url.netloc, scheme=oidc_url.scheme
56-
)
57-
)
58-
if jwks_uri != oidc_config["jwks_uri"]:
59-
logger.warning(
60-
"JWKS URI has been rewritten from %s to %s",
61-
oidc_config["jwks_uri"],
62-
jwks_uri,
63-
)
64-
self._jwks_client = jwt.PyJWKClient(jwks_uri)
65-
except httpx.HTTPStatusError as e:
66-
logger.error(
67-
"Received a non-200 response when fetching OIDC config: %s",
68-
e.response.text,
69-
)
70-
raise OidcFetchError(
71-
f"Request for OIDC config failed with status {e.response.status_code}"
72-
) from e
73-
except httpx.RequestError as e:
74-
logger.error(
75-
"Error fetching OIDC config from %s: %s", origin_url, str(e)
76-
)
77-
raise OidcFetchError(f"Request for OIDC config failed: {str(e)}") from e
78-
return self._jwks_client
80+
_oidc_config: Optional[OidcService] = None
7981

8082
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
8183
"""Enforce authentication."""
@@ -107,6 +109,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
107109
self.state_key,
108110
payload,
109111
)
112+
setattr(request.state, "oidc_metadata", self.oidc_config.metadata)
110113
return await self.app(scope, receive, send)
111114

112115
def validate_token(
@@ -137,7 +140,7 @@ def validate_token(
137140

138141
# Parse & validate token
139142
try:
140-
key = self.jwks_client.get_signing_key_from_jwt(token).key
143+
key = self.oidc_config.jwks_client.get_signing_key_from_jwt(token).key
141144
payload = jwt.decode(
142145
token,
143146
key,
@@ -163,6 +166,13 @@ def validate_token(
163166
)
164167
return payload
165168

169+
@property
170+
def oidc_config(self) -> OidcService:
171+
"""Get the OIDC configuration."""
172+
if not self._oidc_config:
173+
self._oidc_config = OidcService(oidc_config_url=self.oidc_config_url)
174+
return self._oidc_config
175+
166176

167177
class OidcFetchError(Exception):
168178
"""Error fetching OIDC configuration."""

0 commit comments

Comments
 (0)