Skip to content

Commit ab01eca

Browse files
committed
Mv auth to middleware, rework url patterns to regex
1 parent 4c73880 commit ab01eca

File tree

5 files changed

+180
-19
lines changed

5 files changed

+180
-19
lines changed

src/stac_auth_proxy/app.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,13 @@ def create_app(settings: Optional[Settings] = None) -> FastAPI:
3232
)
3333

3434
app.add_middleware(AddProcessTimeHeaderMiddleware)
35-
app.add_middleware(EnforceAuthMiddleware)
35+
app.add_middleware(
36+
EnforceAuthMiddleware,
37+
public_endpoints=settings.public_endpoints,
38+
private_endpoints=settings.private_endpoints,
39+
default_public=settings.default_public,
40+
oidc_config_url=settings.oidc_discovery_url,
41+
)
3642

3743
if settings.openapi_spec_endpoint:
3844
app.add_middleware(

src/stac_auth_proxy/config.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,15 +37,15 @@ class Settings(BaseSettings):
3737
default_public: bool = False
3838
private_endpoints: EndpointMethods = {
3939
# https://github.com/stac-api-extensions/collection-transaction/blob/v1.0.0-beta.1/README.md#methods
40-
"/collections": ["POST"],
41-
"/collections/{collection_id}": ["PUT", "PATCH", "DELETE"],
40+
r"^/collections$": ["POST"],
41+
r"^/collections/([^/]+)$": ["PUT", "PATCH", "DELETE"],
4242
# https://github.com/stac-api-extensions/transaction/blob/v1.0.0-rc.3/README.md#methods
43-
"/collections/{collection_id}/items": ["POST"],
44-
"/collections/{collection_id}/items/{item_id}": ["PUT", "PATCH", "DELETE"],
43+
r"^/collections/([^/]+)/items$": ["POST"],
44+
r"^/collections/([^/]+)/items/([^/]+)$": ["PUT", "PATCH", "DELETE"],
4545
# https://stac-utils.github.io/stac-fastapi/api/stac_fastapi/extensions/third_party/bulk_transactions/#bulktransactionextension
46-
"/collections/{collection_id}/bulk_items": ["POST"],
46+
r"^/collections/([^/]+)/bulk_items$": ["POST"],
4747
}
48-
public_endpoints: EndpointMethods = {"/api.html": ["GET"], "/api": ["GET"]}
48+
public_endpoints: EndpointMethods = {r"^/api.html$": ["GET"], r"^/api$": ["GET"]}
4949
openapi_spec_endpoint: Optional[str] = None
5050

5151
collections_filter: Optional[ClassInput] = None
Lines changed: 142 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,147 @@
1-
# TODO
2-
from fastapi import Request, Response
3-
from starlette.middleware.base import BaseHTTPMiddleware
1+
from dataclasses import dataclass, field
2+
from typing import Annotated, Optional, Sequence
3+
import json
4+
import logging
5+
import urllib.request
46

7+
from fastapi import HTTPException, Security, security, status, Request
8+
from fastapi.security.base import SecurityBase
9+
from pydantic import HttpUrl
10+
from starlette.middleware.base import ASGIApp
11+
from starlette.responses import JSONResponse
12+
from starlette.types import ASGIApp, Message, Receive, Scope, Send
13+
import jwt
514

6-
class EnforceAuthMiddleware(BaseHTTPMiddleware):
15+
from ..config import EndpointMethods
16+
from ..utils.requests import matches_route
17+
18+
logger = logging.getLogger(__name__)
19+
20+
21+
@dataclass
22+
class EnforceAuthMiddleware:
723
"""Middleware to enforce authentication."""
824

9-
async def dispatch(self, request: Request, call_next) -> Response:
25+
app: ASGIApp
26+
private_endpoints: EndpointMethods
27+
public_endpoints: EndpointMethods
28+
default_public: bool
29+
30+
oidc_config_url: HttpUrl
31+
openid_configuration_internal_url: Optional[HttpUrl] = None
32+
allowed_jwt_audiences: Optional[Sequence[str]] = None
33+
34+
# Generated attributes
35+
# auth_scheme: SecurityBase = field(init=False)
36+
jwks_client: jwt.PyJWKClient = field(init=False)
37+
38+
def __post_init__(self):
39+
"""Initialize the OIDC authentication class."""
40+
logger.debug("Requesting OIDC config")
41+
origin_url = str(self.openid_configuration_internal_url or self.oidc_config_url)
42+
with urllib.request.urlopen(origin_url) as response:
43+
if response.status != 200:
44+
logger.error(
45+
"Received a non-200 response when fetching OIDC config: %s",
46+
response.text,
47+
)
48+
raise OidcFetchError(
49+
f"Request for OIDC config failed with status {response.status}"
50+
)
51+
oidc_config = json.load(response)
52+
self.jwks_client = jwt.PyJWKClient(oidc_config["jwks_uri"])
53+
54+
# self.auth_scheme = security.OpenIdConnect(
55+
# openIdConnectUrl=str(self.oidc_config_url),
56+
# auto_error=False,
57+
# )
58+
59+
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
1060
"""Enforce authentication."""
11-
return await call_next(request)
61+
if scope["type"] != "http":
62+
return await self.app(scope, receive, send)
63+
64+
request = Request(scope)
65+
try:
66+
scope["user"] = self.validated_user(
67+
request.headers.get("Authorization"),
68+
security.SecurityScopes(scopes=["read"]),
69+
auto_error=self.should_enforce_auth(request),
70+
)
71+
except HTTPException as e:
72+
response = JSONResponse({"detail": e.detail}, status_code=e.status_code)
73+
return await response(scope, receive, send)
74+
return await self.app(scope, receive, send)
75+
76+
def should_enforce_auth(self, request: Request) -> bool:
77+
"""Determine if authentication should be required on a given request."""
78+
# If default_public, we only enforce auth if the request is for an endpoint explicitly listed as private
79+
if self.default_public:
80+
return matches_route(request, self.private_endpoints)
81+
# If not default_public, we enforce auth if the request is not for an endpoint explicitly listed as public
82+
return not matches_route(request, self.public_endpoints)
83+
84+
def validated_user(
85+
self,
86+
auth_header: Annotated[str, Security(...)],
87+
required_scopes: security.SecurityScopes,
88+
auto_error: bool = True,
89+
):
90+
"""Dependency to validate an OIDC token."""
91+
if not auth_header:
92+
if auto_error:
93+
raise HTTPException(
94+
status_code=status.HTTP_403_FORBIDDEN,
95+
detail="Not authenticated",
96+
)
97+
return None
98+
99+
# Extract token from header
100+
token_parts = auth_header.split(" ")
101+
if len(token_parts) != 2 or token_parts[0].lower() != "bearer":
102+
logger.error(f"Invalid token: {auth_header}")
103+
raise HTTPException(
104+
status_code=status.HTTP_401_UNAUTHORIZED,
105+
detail="Could not validate credentials",
106+
headers={"WWW-Authenticate": "Bearer"},
107+
)
108+
[_, token] = token_parts
109+
110+
# Parse & validate token
111+
try:
112+
key = self.jwks_client.get_signing_key_from_jwt(token).key
113+
payload = jwt.decode(
114+
token,
115+
key,
116+
algorithms=["RS256"],
117+
# NOTE: Audience validation MUST match audience claim if set in token (https://pyjwt.readthedocs.io/en/stable/changelog.html?highlight=audience#id40)
118+
audience=self.allowed_jwt_audiences,
119+
)
120+
except (jwt.exceptions.InvalidTokenError, jwt.exceptions.DecodeError) as e:
121+
logger.exception(f"InvalidTokenError: {e=}")
122+
raise HTTPException(
123+
status_code=status.HTTP_401_UNAUTHORIZED,
124+
detail="Could not validate credentials",
125+
headers={"WWW-Authenticate": "Bearer"},
126+
) from e
127+
128+
# Validate scopes (if required)
129+
for scope in required_scopes.scopes:
130+
if scope not in payload["scope"]:
131+
if auto_error:
132+
raise HTTPException(
133+
status_code=status.HTTP_401_UNAUTHORIZED,
134+
detail="Not enough permissions",
135+
headers={
136+
"WWW-Authenticate": f'Bearer scope="{required_scopes.scope_str}"'
137+
},
138+
)
139+
return None
140+
141+
return payload
142+
143+
144+
class OidcFetchError(Exception):
145+
"""Error fetching OIDC configuration."""
146+
147+
...

src/stac_auth_proxy/utils/requests.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from urllib.parse import urlparse
66

77
from httpx import Headers
8+
from starlette.requests import Request
89

910

1011
def safe_headers(headers: Headers) -> dict[str, str]:
@@ -35,3 +36,21 @@ def extract_variables(url: str) -> dict:
3536
def dict_to_bytes(d: dict) -> bytes:
3637
"""Convert a dictionary to a body."""
3738
return json.dumps(d, separators=(",", ":")).encode("utf-8")
39+
40+
41+
def matches_route(request: Request, url_patterns: dict[str, list[str]]) -> bool:
42+
"""
43+
Returns True if the incoming request.path and request.method
44+
match any of the patterns (and their methods) in url_patterns.
45+
Otherwise returns False.
46+
"""
47+
path = request.url.path # e.g. '/collections/123'
48+
method = request.method.casefold() # e.g. 'post'
49+
50+
for pattern, allowed_methods in url_patterns.items():
51+
if re.match(pattern, path) and method in [
52+
m.casefold() for m in allowed_methods
53+
]:
54+
return True
55+
56+
return False

tests/test_defaults.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,18 +34,18 @@
3434
)
3535
def test_default_public_true(source_api_server, path, method, expected_status):
3636
"""
37-
When default_public=true and private_endpoints aren't set, all endpoints should be
37+
When default_public=true and private_endpoints are set, all endpoints should be
3838
public except for transaction endpoints.
3939
"""
4040
test_app = app_factory(
4141
upstream_url=source_api_server,
4242
public_endpoints={},
4343
private_endpoints={
44-
"/collections": ["POST"],
45-
"/collections/{collection_id}": ["PUT", "PATCH", "DELETE"],
46-
"/collections/{collection_id}/items": ["POST"],
47-
"/collections/{collection_id}/items/{item_id}": ["PUT", "PATCH", "DELETE"],
48-
"/collections/{collection_id}/bulk_items": ["POST"],
44+
r"^/collections$": ["POST"],
45+
r"^/collections/([^/]+)$": ["PUT", "PATCH", "DELETE"],
46+
r"^/collections/([^/]+)/items$": ["POST"],
47+
r"^/collections/([^/]+)/items/([^/]+)$": ["PUT", "PATCH", "DELETE"],
48+
r"^/collections/([^/]+)/bulk_items$": ["POST"],
4949
},
5050
default_public=True,
5151
)

0 commit comments

Comments
 (0)