Skip to content

Commit d36dd2e

Browse files
committed
Rework middleware
1 parent 0a0254b commit d36dd2e

File tree

2 files changed

+44
-36
lines changed

2 files changed

+44
-36
lines changed

src/stac_auth_proxy/middleware/EnforceAuthMiddleware.py

Lines changed: 42 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import logging
44
from dataclasses import dataclass, field
5-
from typing import Annotated, Optional, Sequence
5+
from typing import Annotated, Any, Optional, Sequence
66

77
import httpx
88
import jwt
@@ -30,32 +30,37 @@ class EnforceAuthMiddleware:
3030
oidc_config_internal_url: Optional[HttpUrl] = None
3131
allowed_jwt_audiences: Optional[Sequence[str]] = None
3232

33-
state_key: str = "user"
33+
state_key: str = "payload"
3434

3535
# Generated attributes
36-
jwks_client: jwt.PyJWKClient = field(init=False)
37-
38-
def __post_init__(self):
39-
"""Initialize the OIDC authentication class."""
40-
logger.debug("Requesting OIDC config")
41-
origin_url = str(self.oidc_config_internal_url or self.oidc_config_url)
42-
43-
try:
44-
response = httpx.get(origin_url)
45-
response.raise_for_status()
46-
oidc_config = response.json()
47-
self.jwks_client = jwt.PyJWKClient(oidc_config["jwks_uri"])
48-
except httpx.HTTPStatusError as e:
49-
logger.error(
50-
"Received a non-200 response when fetching OIDC config: %s",
51-
e.response.text,
52-
)
53-
raise OidcFetchError(
54-
f"Request for OIDC config failed with status {e.response.status_code}"
55-
)
56-
except httpx.RequestError as e:
57-
logger.error("Error fetching OIDC config from %s: %s", origin_url, str(e))
58-
raise OidcFetchError(f"Request for OIDC config failed: {str(e)}")
36+
_jwks_client: Optional[jwt.PyJWKClient] = None
37+
38+
@property
39+
def jwks_client(self) -> HttpUrl:
40+
"""Get the OIDC configuration URL."""
41+
if not self._jwks_client:
42+
logger.debug("Requesting OIDC config")
43+
origin_url = str(self.oidc_config_internal_url or self.oidc_config_url)
44+
45+
try:
46+
response = httpx.get(origin_url)
47+
response.raise_for_status()
48+
oidc_config = response.json()
49+
self._jwks_client = jwt.PyJWKClient(oidc_config["jwks_uri"])
50+
except httpx.HTTPStatusError as e:
51+
logger.error(
52+
"Received a non-200 response when fetching OIDC config: %s",
53+
e.response.text,
54+
)
55+
raise OidcFetchError(
56+
f"Request for OIDC config failed with status {e.response.status_code}"
57+
) from e
58+
except httpx.RequestError as e:
59+
logger.error(
60+
"Error fetching OIDC config from %s: %s", origin_url, str(e)
61+
)
62+
raise OidcFetchError(f"Request for OIDC config failed: {str(e)}") from e
63+
return self._jwks_client
5964

6065
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
6166
"""Enforce authentication."""
@@ -64,17 +69,20 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
6469

6570
request = Request(scope)
6671
try:
67-
setattr(
68-
request.state,
69-
self.state_key,
70-
self.validated_user(
71-
request.headers.get("Authorization"),
72-
auto_error=self.should_enforce_auth(request),
73-
),
72+
payload = self.validate_token(
73+
request.headers.get("Authorization"),
74+
auto_error=self.should_enforce_auth(request),
7475
)
7576
except HTTPException as e:
7677
response = JSONResponse({"detail": e.detail}, status_code=e.status_code)
7778
return await response(scope, receive, send)
79+
80+
# Set the payload in the request state
81+
setattr(
82+
request.state,
83+
self.state_key,
84+
payload,
85+
)
7886
return await self.app(scope, receive, send)
7987

8088
def should_enforce_auth(self, request: Request) -> bool:
@@ -85,11 +93,11 @@ def should_enforce_auth(self, request: Request) -> bool:
8593
# If not default_public, we enforce auth if the request is not for an endpoint explicitly listed as public
8694
return not matches_route(request, self.public_endpoints)
8795

88-
def validated_user(
96+
def validate_token(
8997
self,
9098
auth_header: Annotated[str, Security(...)],
9199
auto_error: bool = True,
92-
):
100+
) -> Optional[dict[str, Any]]:
93101
"""Dependency to validate an OIDC token."""
94102
if not auth_header:
95103
if auto_error:

tests/test_filters_jinja2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
id="simple_not_templated",
1919
),
2020
pytest.param(
21-
"{{ '(properties.private = false)' if user is none else true }}",
21+
"{{ '(properties.private = false)' if payload is none else true }}",
2222
"true",
2323
"(properties.private = false)",
2424
id="simple_templated",
@@ -30,7 +30,7 @@
3030
id="complex_not_templated",
3131
),
3232
pytest.param(
33-
"""{{ '{"op": "=", "args": [{"property": "private"}, true]}' if user is none else true }}""",
33+
"""{{ '{"op": "=", "args": [{"property": "private"}, true]}' if payload is none else true }}""",
3434
"true",
3535
"""{"op": "=", "args": [{"property": "private"}, true]}""",
3636
id="complex_templated",

0 commit comments

Comments
 (0)