Skip to content

Commit 58c73ee

Browse files
committed
Merge branch 'main' into examples/oidc-docker-compose
2 parents 7d48ed7 + b44d28a commit 58c73ee

File tree

5 files changed

+209
-49
lines changed

5 files changed

+209
-49
lines changed

src/stac_auth_proxy/config.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,18 @@
11
"""Configuration for the STAC Auth Proxy."""
22

33
import importlib
4-
from typing import Literal, Optional, Sequence, TypeAlias
4+
from typing import Literal, Optional, Sequence, TypeAlias, Union
55

66
from pydantic import BaseModel, Field
77
from pydantic.networks import HttpUrl
88
from pydantic_settings import BaseSettings, SettingsConfigDict
99

10+
METHODS = Literal["GET", "POST", "PUT", "DELETE", "PATCH"]
11+
EndpointMethodsNoScope: TypeAlias = dict[str, Sequence[METHODS]]
1012
EndpointMethods: TypeAlias = dict[
11-
str, list[Literal["GET", "POST", "PUT", "DELETE", "PATCH"]]
13+
str, Sequence[Union[METHODS, tuple[METHODS, Sequence[str]]]]
1214
]
15+
1316
_PREFIX_PATTERN = r"^/.*$"
1417

1518

@@ -44,7 +47,7 @@ class Settings(BaseSettings):
4447

4548
# Auth
4649
default_public: bool = False
47-
public_endpoints: EndpointMethods = {
50+
public_endpoints: EndpointMethodsNoScope = {
4851
r"^/api.html$": ["GET"],
4952
r"^/api$": ["GET"],
5053
r"^/healthz": ["GET"],

src/stac_auth_proxy/middleware/EnforceAuthMiddleware.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from starlette.types import ASGIApp, Receive, Scope, Send
1313

1414
from ..config import EndpointMethods
15-
from ..utils.requests import matches_route
15+
from ..utils.requests import find_match
1616

1717
logger = logging.getLogger(__name__)
1818

@@ -68,11 +68,20 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
6868
return await self.app(scope, receive, send)
6969

7070
request = Request(scope)
71+
match = find_match(
72+
request.url.path,
73+
request.method,
74+
private_endpoints=self.private_endpoints,
75+
public_endpoints=self.public_endpoints,
76+
default_public=self.default_public,
77+
)
7178
try:
7279
payload = self.validate_token(
7380
request.headers.get("Authorization"),
74-
auto_error=self.should_enforce_auth(request),
81+
auto_error=match.is_private,
82+
required_scopes=match.required_scopes,
7583
)
84+
7685
except HTTPException as e:
7786
response = JSONResponse({"detail": e.detail}, status_code=e.status_code)
7887
return await response(scope, receive, send)
@@ -85,18 +94,11 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
8594
)
8695
return await self.app(scope, receive, send)
8796

88-
def should_enforce_auth(self, request: Request) -> bool:
89-
"""Determine if authentication should be required on a given request."""
90-
# If default_public, we only enforce auth if the request is for an endpoint explicitly listed as private
91-
if self.default_public:
92-
return matches_route(request, self.private_endpoints)
93-
# If not default_public, we enforce auth if the request is not for an endpoint explicitly listed as public
94-
return not matches_route(request, self.public_endpoints)
95-
9697
def validate_token(
9798
self,
9899
auth_header: Annotated[str, Security(...)],
99100
auto_error: bool = True,
101+
required_scopes: Optional[Sequence[str]] = None,
100102
) -> Optional[dict[str, Any]]:
101103
"""Dependency to validate an OIDC token."""
102104
if not auth_header:
@@ -136,6 +138,14 @@ def validate_token(
136138
headers={"WWW-Authenticate": "Bearer"},
137139
) from e
138140

141+
if required_scopes:
142+
for scope in required_scopes:
143+
if scope not in payload["scope"].split(" "):
144+
raise HTTPException(
145+
status_code=status.HTTP_401_UNAUTHORIZED,
146+
detail="Not enough permissions",
147+
headers={"WWW-Authenticate": f'Bearer scope="{scope}"'},
148+
)
139149
return payload
140150

141151

src/stac_auth_proxy/middleware/UpdateOpenApiMiddleware.py

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import gzip
44
import json
5-
import re
65
import zlib
76
from dataclasses import dataclass
87
from typing import Any, Optional
@@ -13,7 +12,7 @@
1312
from starlette.types import ASGIApp, Message, Receive, Scope, Send
1413

1514
from ..config import EndpointMethods
16-
from ..utils.requests import dict_to_bytes
15+
from ..utils.requests import dict_to_bytes, find_match
1716

1817
ENCODING_HANDLERS = {
1918
"gzip": gzip,
@@ -112,24 +111,15 @@ def augment_spec(self, openapi_spec) -> dict[str, Any]:
112111
}
113112
for path, method_config in openapi_spec["paths"].items():
114113
for method, config in method_config.items():
115-
requires_auth = (
116-
self.path_matches(path, method, self.private_endpoints)
117-
if self.default_public
118-
else not self.path_matches(path, method, self.public_endpoints)
114+
match = find_match(
115+
path,
116+
method,
117+
self.private_endpoints,
118+
self.public_endpoints,
119+
self.default_public,
119120
)
120-
if requires_auth:
121+
if match.is_private:
121122
config.setdefault("security", []).append(
122-
{self.oidc_auth_scheme_name: []}
123+
{self.oidc_auth_scheme_name: match.required_scopes}
123124
)
124125
return openapi_spec
125-
126-
@staticmethod
127-
def path_matches(path: str, method: str, endpoints: EndpointMethods) -> bool:
128-
"""Check if the given path and method match any of the regex patterns and methods in the endpoints."""
129-
for pattern, endpoint_methods in endpoints.items():
130-
if not re.match(pattern, path):
131-
continue
132-
for endpoint_method in endpoint_methods:
133-
if method.casefold() == endpoint_method.casefold():
134-
return True
135-
return False

src/stac_auth_proxy/utils/requests.py

Lines changed: 34 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22

33
import json
44
import re
5+
from dataclasses import dataclass, field
6+
from typing import Sequence
57
from urllib.parse import urlparse
68

7-
from starlette.requests import Request
8-
99
from ..config import EndpointMethods
1010

1111

@@ -26,18 +26,35 @@ def dict_to_bytes(d: dict) -> bytes:
2626
return json.dumps(d, separators=(",", ":")).encode("utf-8")
2727

2828

29-
def matches_route(request: Request, url_patterns: EndpointMethods) -> bool:
30-
"""
31-
Test if the incoming request.path and request.method match any of the patterns
32-
(and their methods) in url_patterns.
33-
"""
34-
path = request.url.path # e.g. '/collections/123'
35-
method = request.method.casefold() # e.g. 'post'
36-
37-
for pattern, allowed_methods in url_patterns.items():
38-
if re.match(pattern, path) and method in [
39-
m.casefold() for m in allowed_methods
40-
]:
41-
return True
42-
43-
return False
29+
def find_match(
30+
path: str,
31+
method: str,
32+
private_endpoints: EndpointMethods,
33+
public_endpoints: EndpointMethods,
34+
default_public: bool,
35+
) -> "MatchResult":
36+
"""Check if the given path and method match any of the regex patterns and methods in the endpoints."""
37+
endpoints = private_endpoints if default_public else public_endpoints
38+
for pattern, endpoint_methods in endpoints.items():
39+
if not re.match(pattern, path):
40+
continue
41+
for endpoint_method in endpoint_methods:
42+
required_scopes: Sequence[str] = []
43+
if isinstance(endpoint_method, tuple):
44+
endpoint_method, required_scopes = endpoint_method
45+
if method.casefold() == endpoint_method.casefold():
46+
# If default_public, we're looking for a private endpoint.
47+
# If not default_public, we're looking for a public endpoint.
48+
return MatchResult(
49+
is_private=default_public,
50+
required_scopes=required_scopes,
51+
)
52+
return MatchResult(is_private=not default_public)
53+
54+
55+
@dataclass
56+
class MatchResult:
57+
"""Result of a match between a path and method and a set of endpoints."""
58+
59+
is_private: bool
60+
required_scopes: Sequence[str] = field(default_factory=list)

tests/test_authn.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,143 @@ def test_default_public_false(source_api_server, path, method, token_builder):
4848
method=method, url=path, headers={"Authorization": f"Bearer {valid_auth_token}"}
4949
)
5050
assert response.status_code == 200
51+
52+
53+
@pytest.mark.parametrize(
54+
"token_scopes, private_endpoints, path, method, expected_permitted",
55+
[
56+
pytest.param(
57+
"",
58+
{r"^/*": [("POST", ["collections:create"])]},
59+
"/collections",
60+
"POST",
61+
False,
62+
id="empty scopes + private endpoint",
63+
),
64+
pytest.param(
65+
"openid profile collections:createbutnotcreate",
66+
{r"^/*": [("POST", ["collections:create"])]},
67+
"/collections",
68+
"POST",
69+
False,
70+
id="invalid scopes + private endpoint",
71+
),
72+
pytest.param(
73+
"openid profile collections:create somethingelse",
74+
{r"^/*": [("POST", [])]},
75+
"/collections",
76+
"POST",
77+
True,
78+
id="valid scopes + private endpoint without required scopes",
79+
),
80+
pytest.param(
81+
"openid",
82+
{r"^/collections/.*/items$": [("POST", ["collections:create"])]},
83+
"/collections",
84+
"GET",
85+
True,
86+
id="accessing public endpoint with private endpoint required scopes",
87+
),
88+
],
89+
)
90+
def test_scopes(
91+
source_api_server,
92+
token_builder,
93+
token_scopes,
94+
private_endpoints,
95+
path,
96+
method,
97+
expected_permitted,
98+
):
99+
"""Private endpoints permit access with a valid token."""
100+
test_app = app_factory(
101+
upstream_url=source_api_server,
102+
default_public=True,
103+
private_endpoints=private_endpoints,
104+
)
105+
valid_auth_token = token_builder({"scope": token_scopes})
106+
client = TestClient(test_app)
107+
108+
response = client.request(
109+
method=method,
110+
url=path,
111+
headers={"Authorization": f"Bearer {valid_auth_token}"},
112+
)
113+
expected_status_code = 200 if expected_permitted else 401
114+
assert response.status_code == expected_status_code
115+
116+
117+
# @pytest.mark.parametrize(
118+
# "is_valid, path, method",
119+
# [
120+
# *[
121+
# [True, *endpoint_method]
122+
# for endpoint_method in [
123+
# ["/collections", "POST"],
124+
# ["/collections/foo", "PUT"],
125+
# ["/collections/foo", "PATCH"],
126+
# ["/collections/foo/items", "POST"],
127+
# ["/collections/foo/items/bar", "PUT"],
128+
# ["/collections/foo/items/bar", "PATCH"],
129+
# ]
130+
# ],
131+
# *[
132+
# [False, *endpoint_method]
133+
# for endpoint_method in [
134+
# ["/collections/foo", "DELETE"],
135+
# ["/collections/foo/items/bar", "DELETE"],
136+
# ]
137+
# ],
138+
# ],
139+
# )
140+
# def test_scopes(source_api_server, token_builder, is_valid, path, method):
141+
# """Private endpoints permit access with a valid token."""
142+
# test_app = app_factory(
143+
# upstream_url=source_api_server,
144+
# default_public=True,
145+
# private_endpoints={
146+
# r"^/collections$": [
147+
# ("POST", ["collections:create"]),
148+
# ],
149+
# r"^/collections/([^/]+)$": [
150+
# # ("PUT", ["collections:update"]),
151+
# # ("PATCH", ["collections:update"]),
152+
# ("DELETE", ["collections:delete"]),
153+
# ],
154+
# r"^/collections/([^/]+)/items$": [
155+
# ("POST", ["items:create"]),
156+
# ],
157+
# r"^/collections/([^/]+)/items/([^/]+)$": [
158+
# # ("PUT", ["items:update"]),
159+
# # ("PATCH", ["items:update"]),
160+
# ("DELETE", ["items:delete"]),
161+
# ],
162+
# r"^/collections/([^/]+)/bulk_items$": [
163+
# ("POST", ["items:create"]),
164+
# ],
165+
# },
166+
# )
167+
# valid_auth_token = token_builder(
168+
# {
169+
# "scopes": " ".join(
170+
# [
171+
# "collection:create",
172+
# "items:create",
173+
# "collections:update",
174+
# "items:update",
175+
# ]
176+
# )
177+
# }
178+
# )
179+
# client = TestClient(test_app)
180+
181+
# response = client.request(
182+
# method=method,
183+
# url=path,
184+
# headers={"Authorization": f"Bearer {valid_auth_token}"},
185+
# json={} if method != "DELETE" else None,
186+
# )
187+
# if is_valid:
188+
# assert response.status_code == 200
189+
# else:
190+
# assert response.status_code == 403

0 commit comments

Comments
 (0)