Skip to content

Commit b5ecada

Browse files
committed
feat: validate auth tokens
1 parent 22696b8 commit b5ecada

File tree

5 files changed

+157
-7
lines changed

5 files changed

+157
-7
lines changed

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ classifiers = [
88
dependencies = [
99
"authlib>=1.3.2",
1010
"brotli>=1.1.0",
11+
"eoapi-auth-utils>=0.4.0",
1112
"fastapi>=0.115.5",
1213
"httpx>=0.28.0",
1314
"pydantic-settings>=2.6.1",
@@ -29,7 +30,7 @@ known_first_party = ["stac_auth_proxy"]
2930
profile = "black"
3031

3132
[tool.ruff]
32-
ignore = ["E501", "D205", "D212"]
33+
ignore = ["E501", "D203", "D205", "D212"]
3334
select = ["D", "E", "F"]
3435

3536
[build-system]
@@ -38,6 +39,7 @@ requires = ["hatchling>=1.12.0"]
3839

3940
[dependency-groups]
4041
dev = [
42+
"jwcrypto>=1.5.6",
4143
"pre-commit>=3.5.0",
4244
"pytest>=8.3.3",
4345
]

src/stac_auth_proxy/app.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77

88
from typing import Optional
99

10+
from eoapi.auth_utils import OpenIdConnectAuth
1011
from fastapi import Depends, FastAPI
11-
from fastapi.security import OpenIdConnect
1212

1313
from .config import Settings
1414
from .handlers import OpenApiSpecHandler
@@ -21,13 +21,12 @@ def create_app(settings: Optional[Settings] = None) -> FastAPI:
2121
settings = settings or Settings()
2222

2323
app = FastAPI(openapi_url=None)
24+
2425
app.add_middleware(AddProcessTimeHeaderMiddleware)
2526

26-
auth_scheme = OpenIdConnect(
27-
openIdConnectUrl=str(settings.oidc_discovery_url),
28-
scheme_name="OpenID Connect",
29-
description="OpenID Connect authentication for STAC API access",
30-
)
27+
auth_scheme = OpenIdConnectAuth(
28+
openid_configuration_url=str(settings.oidc_discovery_url)
29+
).valid_token_dependency
3130

3231
proxy = ReverseProxy(upstream=str(settings.upstream_url))
3332

tests/conftest.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,65 @@
11
"""Pytest fixtures."""
22

3+
import json
34
import threading
5+
from typing import Any
6+
from unittest.mock import MagicMock, patch
47

58
import pytest
69
import uvicorn
710
from fastapi import FastAPI
11+
from jwcrypto import jwk, jwt
12+
13+
14+
@pytest.fixture
15+
def test_key() -> jwk.JWK:
16+
"""Generate a test RSA key."""
17+
return jwk.JWK.generate(
18+
kty="RSA", size=2048, kid="test", use="sig", e="AQAB", alg="RS256"
19+
)
20+
21+
22+
@pytest.fixture
23+
def public_key(test_key: jwk.JWK) -> dict[str, Any]:
24+
"""Export public key."""
25+
return test_key.export_public(as_dict=True)
26+
27+
28+
@pytest.fixture(autouse=True)
29+
def mock_jwks(public_key: dict[str, Any]):
30+
"""Mock JWKS endpoint."""
31+
mock_oidc_config = {"jwks_uri": "https://example.com/jwks"}
32+
33+
mock_jwks = {"keys": [public_key]}
34+
35+
with (
36+
patch("urllib.request.urlopen") as mock_urlopen,
37+
patch("jwt.PyJWKClient.fetch_data") as mock_fetch_data,
38+
):
39+
mock_oidc_config_response = MagicMock()
40+
mock_oidc_config_response.read.return_value = json.dumps(
41+
mock_oidc_config
42+
).encode()
43+
mock_oidc_config_response.status = 200
44+
45+
mock_urlopen.return_value.__enter__.return_value = mock_oidc_config_response
46+
mock_fetch_data.return_value = mock_jwks
47+
yield mock_urlopen
48+
49+
50+
@pytest.fixture
51+
def token_builder(test_key: jwk.JWK):
52+
"""Generate a valid JWT token builder."""
53+
54+
def build_token(payload: dict[str, Any], key=None) -> str:
55+
jwt_token = jwt.JWT(
56+
header={k: test_key.get(k) for k in ["alg", "kid"]},
57+
claims=payload,
58+
)
59+
jwt_token.make_signed_token(key or test_key)
60+
return jwt_token.serialize()
61+
62+
return build_token
863

964

1065
@pytest.fixture(scope="session")

tests/test_authn.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
"""Test authentication cases for the proxy app."""
2+
3+
import pytest
4+
from fastapi.testclient import TestClient
5+
6+
from utils import AppFactory
7+
8+
9+
app_factory = AppFactory(
10+
oidc_discovery_url="https://samples.auth0.com/.well-known/openid-configuration",
11+
default_public=False,
12+
public_endpoints={},
13+
private_endpoints={},
14+
)
15+
16+
17+
@pytest.mark.parametrize(
18+
"path,method",
19+
[
20+
("/", "GET"),
21+
("/conformance", "GET"),
22+
("/queryables", "GET"),
23+
("/search", "GET"),
24+
("/search", "POST"),
25+
("/collections", "GET"),
26+
("/collections", "POST"),
27+
("/collections/example-collection", "GET"),
28+
("/collections/example-collection", "PUT"),
29+
("/collections/example-collection", "DELETE"),
30+
("/collections/example-collection/items", "GET"),
31+
("/collections/example-collection/items", "POST"),
32+
("/collections/example-collection/items/example-item", "GET"),
33+
("/collections/example-collection/items/example-item", "PUT"),
34+
("/collections/example-collection/items/example-item", "DELETE"),
35+
("/collections/example-collection/bulk_items", "POST"),
36+
("/api.html", "GET"),
37+
("/api", "GET"),
38+
],
39+
)
40+
def test_default_public_false(source_api_server, path, method, token_builder):
41+
"""
42+
Private endpoints permit access with a valid token.
43+
"""
44+
test_app = app_factory(upstream_url=source_api_server)
45+
valid_auth_token = token_builder({})
46+
47+
client = TestClient(test_app)
48+
response = client.request(method=method, url=path, headers={})
49+
assert response.status_code == 403
50+
51+
response = client.request(
52+
method=method, url=path, headers={"Authorization": f"Bearer {valid_auth_token}"}
53+
)
54+
assert response.status_code == 200

uv.lock

Lines changed: 40 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)