Skip to content

Commit e7df7c9

Browse files
committed
fix: prevent down oidc from interfering with lifespan
1 parent 2fe0852 commit e7df7c9

File tree

3 files changed

+47
-36
lines changed

3 files changed

+47
-36
lines changed

src/stac_auth_proxy/middleware/EnforceAuthMiddleware.py

Lines changed: 40 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
"""Middleware to enforce authentication."""
22

3-
import json
43
import logging
5-
import urllib.request
6-
from dataclasses import dataclass, field
7-
from typing import Annotated, Optional, Sequence
4+
from dataclasses import dataclass
5+
from typing import Annotated, Any, Optional, Sequence
86

7+
import httpx
98
import jwt
109
from fastapi import HTTPException, Request, Security, status
1110
from pydantic import HttpUrl
@@ -28,29 +27,40 @@ class EnforceAuthMiddleware:
2827
default_public: bool
2928

3029
oidc_config_url: HttpUrl
31-
openid_configuration_internal_url: Optional[HttpUrl] = None
30+
oidc_config_internal_url: Optional[HttpUrl] = None
3231
allowed_jwt_audiences: Optional[Sequence[str]] = None
3332

34-
state_key: str = "user"
33+
state_key: str = "payload"
3534

3635
# Generated attributes
37-
jwks_client: jwt.PyJWKClient = field(init=False)
38-
39-
def __post_init__(self):
40-
"""Initialize the OIDC authentication class."""
41-
logger.debug("Requesting OIDC config")
42-
origin_url = str(self.openid_configuration_internal_url or self.oidc_config_url)
43-
with urllib.request.urlopen(origin_url) as response:
44-
if response.status != 200:
36+
_jwks_client: Optional[jwt.PyJWKClient] = None
37+
38+
@property
39+
def jwks_client(self) -> jwt.PyJWKClient:
40+
"""Get the OIDC configuration URL."""
41+
if not self._jwks_client:
42+
logger.debug("Requesting OIDC config")
43+
origin_url = str(self.oidc_config_internal_url or self.oidc_config_url)
44+
45+
try:
46+
response = httpx.get(origin_url)
47+
response.raise_for_status()
48+
oidc_config = response.json()
49+
self._jwks_client = jwt.PyJWKClient(oidc_config["jwks_uri"])
50+
except httpx.HTTPStatusError as e:
4551
logger.error(
4652
"Received a non-200 response when fetching OIDC config: %s",
47-
response.text,
53+
e.response.text,
4854
)
4955
raise OidcFetchError(
50-
f"Request for OIDC config failed with status {response.status}"
56+
f"Request for OIDC config failed with status {e.response.status_code}"
57+
) from e
58+
except httpx.RequestError as e:
59+
logger.error(
60+
"Error fetching OIDC config from %s: %s", origin_url, str(e)
5161
)
52-
oidc_config = json.load(response)
53-
self.jwks_client = jwt.PyJWKClient(oidc_config["jwks_uri"])
62+
raise OidcFetchError(f"Request for OIDC config failed: {str(e)}") from e
63+
return self._jwks_client
5464

5565
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
5666
"""Enforce authentication."""
@@ -59,17 +69,20 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
5969

6070
request = Request(scope)
6171
try:
62-
setattr(
63-
request.state,
64-
self.state_key,
65-
self.validated_user(
66-
request.headers.get("Authorization"),
67-
auto_error=self.should_enforce_auth(request),
68-
),
72+
payload = self.validate_token(
73+
request.headers.get("Authorization"),
74+
auto_error=self.should_enforce_auth(request),
6975
)
7076
except HTTPException as e:
7177
response = JSONResponse({"detail": e.detail}, status_code=e.status_code)
7278
return await response(scope, receive, send)
79+
80+
# Set the payload in the request state
81+
setattr(
82+
request.state,
83+
self.state_key,
84+
payload,
85+
)
7386
return await self.app(scope, receive, send)
7487

7588
def should_enforce_auth(self, request: Request) -> bool:
@@ -80,11 +93,11 @@ def should_enforce_auth(self, request: Request) -> bool:
8093
# If not default_public, we enforce auth if the request is not for an endpoint explicitly listed as public
8194
return not matches_route(request, self.public_endpoints)
8295

83-
def validated_user(
96+
def validate_token(
8497
self,
8598
auth_header: Annotated[str, Security(...)],
8699
auto_error: bool = True,
87-
):
100+
) -> Optional[dict[str, Any]]:
88101
"""Dependency to validate an OIDC token."""
89102
if not auth_header:
90103
if auto_error:

tests/conftest.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,16 +36,14 @@ def mock_jwks(public_key: dict[str, Any]):
3636
mock_jwks = {"keys": [public_key]}
3737

3838
with (
39-
patch("urllib.request.urlopen") as mock_urlopen,
39+
patch("httpx.get") as mock_urlopen,
4040
patch("jwt.PyJWKClient.fetch_data") as mock_fetch_data,
4141
):
4242
mock_oidc_config_response = MagicMock()
43-
mock_oidc_config_response.read.return_value = json.dumps(
44-
mock_oidc_config
45-
).encode()
43+
mock_oidc_config_response.json.return_value = mock_oidc_config
4644
mock_oidc_config_response.status = 200
4745

48-
mock_urlopen.return_value.__enter__.return_value = mock_oidc_config_response
46+
mock_urlopen.return_value = mock_oidc_config_response
4947
mock_fetch_data.return_value = mock_jwks
5048
yield mock_urlopen
5149

@@ -121,7 +119,7 @@ def source_api():
121119
return app
122120

123121

124-
@pytest.fixture
122+
@pytest.fixture(scope="session")
125123
def source_api_server(source_api):
126124
"""Run the source API in a background thread."""
127125
host, port = "127.0.0.1", 9119
@@ -139,7 +137,7 @@ def source_api_server(source_api):
139137
thread.join()
140138

141139

142-
@pytest.fixture(autouse=True, scope="module")
140+
@pytest.fixture(autouse=True, scope="session")
143141
def mock_env():
144142
"""Clear environment variables to avoid poluting configs from runtime env."""
145143
with patch.dict(os.environ, clear=True):

tests/test_filters_jinja2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
id="simple_not_templated",
1919
),
2020
pytest.param(
21-
"{{ '(properties.private = false)' if user is none else true }}",
21+
"{{ '(properties.private = false)' if payload is none else true }}",
2222
"true",
2323
"(properties.private = false)",
2424
id="simple_templated",
@@ -30,7 +30,7 @@
3030
id="complex_not_templated",
3131
),
3232
pytest.param(
33-
"""{{ '{"op": "=", "args": [{"property": "private"}, true]}' if user is none else true }}""",
33+
"""{{ '{"op": "=", "args": [{"property": "private"}, true]}' if payload is none else true }}""",
3434
"true",
3535
"""{"op": "=", "args": [{"property": "private"}, true]}""",
3636
id="complex_templated",

0 commit comments

Comments
 (0)