Skip to content

Commit 3a0886a

Browse files
committed
Bring eoapi-auth-utils into this lib, customize to permit optional auth
1 parent 5ad3047 commit 3a0886a

File tree

2 files changed

+120
-2
lines changed

2 files changed

+120
-2
lines changed

src/stac_auth_proxy/app.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88
import logging
99
from typing import Optional
1010

11-
from eoapi.auth_utils import OpenIdConnectAuth
1211
from fastapi import Depends, FastAPI
1312

13+
from .auth import OpenIdConnectAuth
1414
from .config import Settings
1515
from .handlers import OpenApiSpecHandler, ReverseProxyHandler
1616
from .middleware import AddProcessTimeHeaderMiddleware
@@ -31,7 +31,8 @@ def create_app(settings: Optional[Settings] = None) -> FastAPI:
3131

3232
proxy_handler = ReverseProxyHandler(upstream=str(settings.upstream_url))
3333
openapi_handler = OpenApiSpecHandler(
34-
proxy=proxy_handler, oidc_config_url=str(settings.oidc_discovery_url)
34+
proxy=proxy_handler,
35+
oidc_config_url=str(settings.oidc_discovery_url),
3536
)
3637

3738
# Endpoints that are explicitely marked private

src/stac_auth_proxy/auth.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
import json
2+
import logging
3+
import urllib.request
4+
from dataclasses import dataclass, field
5+
from typing import Annotated, Any, Callable, Optional, Sequence
6+
7+
import jwt
8+
from fastapi import HTTPException, Security, security, status
9+
from fastapi.security.base import SecurityBase
10+
from starlette.exceptions import HTTPException
11+
from starlette.status import HTTP_403_FORBIDDEN
12+
from pydantic import AnyHttpUrl
13+
14+
15+
logger = logging.getLogger(__name__)
16+
17+
18+
@dataclass
19+
class OpenIdConnectAuth:
20+
openid_configuration_url: AnyHttpUrl
21+
openid_configuration_internal_url: Optional[AnyHttpUrl] = None
22+
allowed_jwt_audiences: Optional[Sequence[str]] = None
23+
24+
# Generated attributes
25+
auth_scheme: SecurityBase = field(init=False)
26+
jwks_client: jwt.PyJWKClient = field(init=False)
27+
valid_token_dependency: Callable[..., Any] = field(init=False)
28+
29+
def __post_init__(self):
30+
logger.debug("Requesting OIDC config")
31+
with urllib.request.urlopen(
32+
str(self.openid_configuration_internal_url or self.openid_configuration_url)
33+
) as response:
34+
if response.status != 200:
35+
logger.error(
36+
"Received a non-200 response when fetching OIDC config: %s",
37+
response.text,
38+
)
39+
raise OidcFetchError(
40+
f"Request for OIDC config failed with status {response.status}"
41+
)
42+
oidc_config = json.load(response)
43+
self.jwks_client = jwt.PyJWKClient(oidc_config["jwks_uri"])
44+
45+
self.valid_token_dependency.__annotations__["auth_header"] = (
46+
security.OpenIdConnect(
47+
openIdConnectUrl=str(self.openid_configuration_url), auto_error=True
48+
)
49+
)
50+
51+
def user_or_none(self, auth_header: Annotated[str, Security(auth_scheme)]):
52+
"""Return the validated user if authenticated, else None."""
53+
return self.valid_token_dependency(
54+
auth_header, security.SecurityScopes([]), auto_error=False
55+
)
56+
57+
def valid_token_dependency(
58+
self,
59+
auth_header: Annotated[str, Security(auth_scheme)],
60+
required_scopes: security.SecurityScopes,
61+
auto_error: bool = True,
62+
):
63+
"""Dependency to validate an OIDC token."""
64+
if not auth_header:
65+
if auto_error:
66+
raise HTTPException(
67+
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
68+
)
69+
return None
70+
71+
# Extract token from header
72+
token_parts = auth_header.split(" ")
73+
if len(token_parts) != 2 or token_parts[0].lower() != "bearer":
74+
logger.error(f"Invalid token: {auth_header}")
75+
raise HTTPException(
76+
status_code=status.HTTP_401_UNAUTHORIZED,
77+
detail="Could not validate credentials",
78+
headers={"WWW-Authenticate": "Bearer"},
79+
)
80+
[_, token] = token_parts
81+
82+
# Parse & validate token
83+
try:
84+
key = self.jwks_client.get_signing_key_from_jwt(token).key
85+
payload = jwt.decode(
86+
token,
87+
key,
88+
algorithms=["RS256"],
89+
# NOTE: Audience validation MUST match audience claim if set in token (https://pyjwt.readthedocs.io/en/stable/changelog.html?highlight=audience#id40)
90+
audience=self.allowed_jwt_audiences,
91+
)
92+
except (jwt.exceptions.InvalidTokenError, jwt.exceptions.DecodeError) as e:
93+
logger.exception(f"InvalidTokenError: {e=}")
94+
raise HTTPException(
95+
status_code=status.HTTP_401_UNAUTHORIZED,
96+
detail="Could not validate credentials",
97+
headers={"WWW-Authenticate": "Bearer"},
98+
) from e
99+
100+
# Validate scopes (if required)
101+
for scope in required_scopes.scopes:
102+
if scope not in payload["scope"]:
103+
if auto_error:
104+
raise HTTPException(
105+
status_code=status.HTTP_401_UNAUTHORIZED,
106+
detail="Not enough permissions",
107+
headers={
108+
"WWW-Authenticate": f'Bearer scope="{required_scopes.scope_str}"'
109+
},
110+
)
111+
return None
112+
113+
return payload
114+
115+
116+
class OidcFetchError(Exception):
117+
pass

0 commit comments

Comments
 (0)