Skip to content

Commit 37fc37d

Browse files
committed
feat: optionally validate scopes
rework path matching
1 parent 13acd0f commit 37fc37d

File tree

5 files changed

+193
-44
lines changed

5 files changed

+193
-44
lines changed

src/stac_auth_proxy/config.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
"""Configuration for the STAC Auth Proxy."""
22

33
import importlib
4-
from typing import Literal, Optional, Sequence, TypeAlias
4+
from typing import Literal, Optional, Sequence, TypeAlias, Union
55

66
from pydantic import BaseModel, Field
77
from pydantic.networks import HttpUrl
88
from pydantic_settings import BaseSettings, SettingsConfigDict
99

10+
METHODS = Literal["GET", "POST", "PUT", "DELETE", "PATCH"]
1011
EndpointMethods: TypeAlias = dict[
11-
str, list[Literal["GET", "POST", "PUT", "DELETE", "PATCH"]]
12+
str, Sequence[Union[METHODS, tuple[METHODS, Sequence[str]]]]
1213
]
1314
_PREFIX_PATTERN = r"^/.*$"
1415

src/stac_auth_proxy/middleware/EnforceAuthMiddleware.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from starlette.types import ASGIApp, Receive, Scope, Send
1313

1414
from ..config import EndpointMethods
15-
from ..utils.requests import matches_route
15+
from ..utils.requests import find_match
1616

1717
logger = logging.getLogger(__name__)
1818

@@ -68,11 +68,20 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
6868
return await self.app(scope, receive, send)
6969

7070
request = Request(scope)
71+
match = find_match(
72+
request.url.path,
73+
request.method,
74+
private_endpoints=self.private_endpoints,
75+
public_endpoints=self.public_endpoints,
76+
default_public=self.default_public,
77+
)
7178
try:
7279
payload = self.validate_token(
7380
request.headers.get("Authorization"),
74-
auto_error=self.should_enforce_auth(request),
81+
auto_error=match["is_private"],
82+
required_scopes=match["scopes"],
7583
)
84+
7685
except HTTPException as e:
7786
response = JSONResponse({"detail": e.detail}, status_code=e.status_code)
7887
return await response(scope, receive, send)
@@ -85,18 +94,11 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
8594
)
8695
return await self.app(scope, receive, send)
8796

88-
def should_enforce_auth(self, request: Request) -> bool:
89-
"""Determine if authentication should be required on a given request."""
90-
# If default_public, we only enforce auth if the request is for an endpoint explicitly listed as private
91-
if self.default_public:
92-
return matches_route(request, self.private_endpoints)
93-
# If not default_public, we enforce auth if the request is not for an endpoint explicitly listed as public
94-
return not matches_route(request, self.public_endpoints)
95-
9697
def validate_token(
9798
self,
9899
auth_header: Annotated[str, Security(...)],
99100
auto_error: bool = True,
101+
required_scopes: Optional[Sequence[str]] = None,
100102
) -> Optional[dict[str, Any]]:
101103
"""Dependency to validate an OIDC token."""
102104
if not auth_header:
@@ -136,6 +138,14 @@ def validate_token(
136138
headers={"WWW-Authenticate": "Bearer"},
137139
) from e
138140

141+
if required_scopes:
142+
for scope in required_scopes:
143+
if scope not in payload["scope"].split(" "):
144+
raise HTTPException(
145+
status_code=status.HTTP_401_UNAUTHORIZED,
146+
detail="Not enough permissions",
147+
headers={"WWW-Authenticate": f'Bearer scope="{scope}"'},
148+
)
139149
return payload
140150

141151

src/stac_auth_proxy/middleware/UpdateOpenApiMiddleware.py

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from starlette.types import ASGIApp, Message, Receive, Scope, Send
1414

1515
from ..config import EndpointMethods
16-
from ..utils.requests import dict_to_bytes
16+
from ..utils.requests import dict_to_bytes, find_match
1717

1818
ENCODING_HANDLERS = {
1919
"gzip": gzip,
@@ -112,24 +112,15 @@ def augment_spec(self, openapi_spec) -> dict[str, Any]:
112112
}
113113
for path, method_config in openapi_spec["paths"].items():
114114
for method, config in method_config.items():
115-
requires_auth = (
116-
self.path_matches(path, method, self.private_endpoints)
117-
if self.default_public
118-
else not self.path_matches(path, method, self.public_endpoints)
115+
match = find_match(
116+
path,
117+
method,
118+
self.private_endpoints,
119+
self.public_endpoints,
120+
self.default_public,
119121
)
120-
if requires_auth:
122+
if match["is_private"]:
121123
config.setdefault("security", []).append(
122124
{self.oidc_auth_scheme_name: []}
123125
)
124126
return openapi_spec
125-
126-
@staticmethod
127-
def path_matches(path: str, method: str, endpoints: EndpointMethods) -> bool:
128-
"""Check if the given path and method match any of the regex patterns and methods in the endpoints."""
129-
for pattern, endpoint_methods in endpoints.items():
130-
if not re.match(pattern, path):
131-
continue
132-
for endpoint_method in endpoint_methods:
133-
if method.casefold() == endpoint_method.casefold():
134-
return True
135-
return False

src/stac_auth_proxy/utils/requests.py

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import json
44
import re
5+
from typing import Optional, Sequence
56
from urllib.parse import urlparse
67

78
from starlette.requests import Request
@@ -26,18 +27,24 @@ def dict_to_bytes(d: dict) -> bytes:
2627
return json.dumps(d, separators=(",", ":")).encode("utf-8")
2728

2829

29-
def matches_route(request: Request, url_patterns: EndpointMethods) -> bool:
30-
"""
31-
Test if the incoming request.path and request.method match any of the patterns
32-
(and their methods) in url_patterns.
33-
"""
34-
path = request.url.path # e.g. '/collections/123'
35-
method = request.method.casefold() # e.g. 'post'
36-
37-
for pattern, allowed_methods in url_patterns.items():
38-
if re.match(pattern, path) and method in [
39-
m.casefold() for m in allowed_methods
40-
]:
41-
return True
42-
43-
return False
30+
def find_match(
31+
path: str,
32+
method: str,
33+
private_endpoints: EndpointMethods,
34+
public_endpoints: EndpointMethods,
35+
default_public: bool,
36+
) -> Optional[tuple[str, str, Sequence[str]]]:
37+
"""Check if the given path and method match any of the regex patterns and methods in the endpoints."""
38+
endpoints = private_endpoints if default_public else public_endpoints
39+
for pattern, endpoint_methods in endpoints.items():
40+
if not re.match(pattern, path):
41+
continue
42+
for endpoint_method in endpoint_methods:
43+
required_scopes: Sequence[str] = []
44+
if isinstance(endpoint_method, tuple):
45+
endpoint_method, required_scopes = endpoint_method
46+
if method.casefold() == endpoint_method.casefold():
47+
# If default_public, we're looking for a private endpoint.
48+
# If not default_public, we're looking for a public endpoint.
49+
return {"is_private": default_public, "scopes": required_scopes}
50+
return {"is_private": not default_public, "scopes": []}

tests/test_authn.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,143 @@ def test_default_public_false(source_api_server, path, method, token_builder):
4848
method=method, url=path, headers={"Authorization": f"Bearer {valid_auth_token}"}
4949
)
5050
assert response.status_code == 200
51+
52+
53+
@pytest.mark.parametrize(
54+
"token_scopes, private_endpoints, path, method, expected_permitted",
55+
[
56+
pytest.param(
57+
"",
58+
{r"^/*": [("POST", ["collections:create"])]},
59+
"/collections",
60+
"POST",
61+
False,
62+
id="empty scopes + private endpoint",
63+
),
64+
pytest.param(
65+
"openid profile collections:createbutnotcreate",
66+
{r"^/*": [("POST", ["collections:create"])]},
67+
"/collections",
68+
"POST",
69+
False,
70+
id="invalid scopes + private endpoint",
71+
),
72+
pytest.param(
73+
"openid profile collections:create somethingelse",
74+
{r"^/*": [("POST", [])]},
75+
"/collections",
76+
"POST",
77+
True,
78+
id="valid scopes + private endpoint without required scopes",
79+
),
80+
pytest.param(
81+
"openid",
82+
{r"^/collections/.*/items$": [("POST", ["collections:create"])]},
83+
"/collections",
84+
"GET",
85+
True,
86+
id="accessing public endpoint with private endpoint required scopes",
87+
),
88+
],
89+
)
90+
def test_scopes(
91+
source_api_server,
92+
token_builder,
93+
token_scopes,
94+
private_endpoints,
95+
path,
96+
method,
97+
expected_permitted,
98+
):
99+
"""Private endpoints permit access with a valid token."""
100+
test_app = app_factory(
101+
upstream_url=source_api_server,
102+
default_public=True,
103+
private_endpoints=private_endpoints,
104+
)
105+
valid_auth_token = token_builder({"scope": token_scopes})
106+
client = TestClient(test_app)
107+
108+
response = client.request(
109+
method=method,
110+
url=path,
111+
headers={"Authorization": f"Bearer {valid_auth_token}"},
112+
)
113+
expected_status_code = 200 if expected_permitted else 401
114+
assert response.status_code == expected_status_code
115+
116+
117+
# @pytest.mark.parametrize(
118+
# "is_valid, path, method",
119+
# [
120+
# *[
121+
# [True, *endpoint_method]
122+
# for endpoint_method in [
123+
# ["/collections", "POST"],
124+
# ["/collections/foo", "PUT"],
125+
# ["/collections/foo", "PATCH"],
126+
# ["/collections/foo/items", "POST"],
127+
# ["/collections/foo/items/bar", "PUT"],
128+
# ["/collections/foo/items/bar", "PATCH"],
129+
# ]
130+
# ],
131+
# *[
132+
# [False, *endpoint_method]
133+
# for endpoint_method in [
134+
# ["/collections/foo", "DELETE"],
135+
# ["/collections/foo/items/bar", "DELETE"],
136+
# ]
137+
# ],
138+
# ],
139+
# )
140+
# def test_scopes(source_api_server, token_builder, is_valid, path, method):
141+
# """Private endpoints permit access with a valid token."""
142+
# test_app = app_factory(
143+
# upstream_url=source_api_server,
144+
# default_public=True,
145+
# private_endpoints={
146+
# r"^/collections$": [
147+
# ("POST", ["collections:create"]),
148+
# ],
149+
# r"^/collections/([^/]+)$": [
150+
# # ("PUT", ["collections:update"]),
151+
# # ("PATCH", ["collections:update"]),
152+
# ("DELETE", ["collections:delete"]),
153+
# ],
154+
# r"^/collections/([^/]+)/items$": [
155+
# ("POST", ["items:create"]),
156+
# ],
157+
# r"^/collections/([^/]+)/items/([^/]+)$": [
158+
# # ("PUT", ["items:update"]),
159+
# # ("PATCH", ["items:update"]),
160+
# ("DELETE", ["items:delete"]),
161+
# ],
162+
# r"^/collections/([^/]+)/bulk_items$": [
163+
# ("POST", ["items:create"]),
164+
# ],
165+
# },
166+
# )
167+
# valid_auth_token = token_builder(
168+
# {
169+
# "scopes": " ".join(
170+
# [
171+
# "collection:create",
172+
# "items:create",
173+
# "collections:update",
174+
# "items:update",
175+
# ]
176+
# )
177+
# }
178+
# )
179+
# client = TestClient(test_app)
180+
181+
# response = client.request(
182+
# method=method,
183+
# url=path,
184+
# headers={"Authorization": f"Bearer {valid_auth_token}"},
185+
# json={} if method != "DELETE" else None,
186+
# )
187+
# if is_valid:
188+
# assert response.status_code == 200
189+
# else:
190+
# assert response.status_code == 403

0 commit comments

Comments
 (0)