Skip to content

Commit 6dfa54d

Browse files
committed
fix: check private endpoint scopes when default_public=False
1 parent 48bdbe4 commit 6dfa54d

File tree

2 files changed

+70
-21
lines changed

2 files changed

+70
-21
lines changed

src/stac_auth_proxy/utils/requests.py

Lines changed: 36 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,23 @@ def dict_to_bytes(d: dict) -> bytes:
2626
return json.dumps(d, separators=(",", ":")).encode("utf-8")
2727

2828

29+
def _check_endpoint_match(
30+
path: str,
31+
method: str,
32+
endpoints: EndpointMethods,
33+
) -> tuple[bool, Sequence[str]]:
34+
"""Check if the path and method match any endpoint in the given endpoints map."""
35+
for pattern, endpoint_methods in endpoints.items():
36+
if re.match(pattern, path):
37+
for endpoint_method in endpoint_methods:
38+
required_scopes: Sequence[str] = []
39+
if isinstance(endpoint_method, tuple):
40+
endpoint_method, required_scopes = endpoint_method
41+
if method.casefold() == endpoint_method.casefold():
42+
return True, required_scopes
43+
return False, []
44+
45+
2946
def find_match(
3047
path: str,
3148
method: str,
@@ -34,22 +51,25 @@ def find_match(
3451
default_public: bool,
3552
) -> "MatchResult":
3653
"""Check if the given path and method match any of the regex patterns and methods in the endpoints."""
37-
endpoints = private_endpoints if default_public else public_endpoints
38-
for pattern, endpoint_methods in endpoints.items():
39-
if not re.match(pattern, path):
40-
continue
41-
for endpoint_method in endpoint_methods:
42-
required_scopes: Sequence[str] = []
43-
if isinstance(endpoint_method, tuple):
44-
endpoint_method, required_scopes = endpoint_method
45-
if method.casefold() == endpoint_method.casefold():
46-
# If default_public, we're looking for a private endpoint.
47-
# If not default_public, we're looking for a public endpoint.
48-
return MatchResult(
49-
is_private=default_public,
50-
required_scopes=required_scopes,
51-
)
52-
return MatchResult(is_private=not default_public)
54+
primary_endpoints = private_endpoints if default_public else public_endpoints
55+
matched, required_scopes = _check_endpoint_match(path, method, primary_endpoints)
56+
if matched:
57+
return MatchResult(
58+
is_private=default_public,
59+
required_scopes=required_scopes,
60+
)
61+
62+
# If default_public and no match found in private_endpoints, it's public
63+
if default_public:
64+
return MatchResult(is_private=False)
65+
66+
# If not default_public, check private_endpoints for required scopes
67+
matched, required_scopes = _check_endpoint_match(path, method, private_endpoints)
68+
if matched:
69+
return MatchResult(is_private=True, required_scopes=required_scopes)
70+
71+
# Default case: if not default_public and no explicit match, it's private
72+
return MatchResult(is_private=True)
5373

5474

5575
@dataclass

tests/test_authn.py

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,27 +50,56 @@ def test_default_public_false(source_api_server, path, method, token_builder):
5050
assert response.status_code == 200
5151

5252

53+
@pytest.mark.parametrize(
54+
"token,permitted",
55+
[
56+
[{"scope": "collection:create"}, True],
57+
[{"scope": ""}, False],
58+
[{"scope": "openid"}, False],
59+
[{"scope": "openid collection:create"}, True],
60+
],
61+
)
62+
def test_default_public_false_with_scopes(
63+
source_api_server, token, permitted, token_builder
64+
):
65+
"""Private endpoints permit access with a valid token."""
66+
test_app = app_factory(
67+
upstream_url=source_api_server,
68+
default_public=False,
69+
private_endpoints={r"^/collections$": [("POST", ["collection:create"])]},
70+
)
71+
valid_auth_token = token_builder(token)
72+
73+
client = TestClient(test_app)
74+
response = client.request(
75+
method="POST",
76+
url="/collections",
77+
headers={"Authorization": f"Bearer {valid_auth_token}"},
78+
)
79+
assert response.status_code == (200 if permitted else 401)
80+
81+
5382
@pytest.mark.parametrize(
5483
"token_scopes, private_endpoints, path, method, expected_permitted",
5584
[
5685
pytest.param(
5786
"",
58-
{r"^/*": [("POST", ["collections:create"])]},
87+
{r"^/*": [("POST", ["collection:create"])]},
5988
"/collections",
6089
"POST",
6190
False,
6291
id="empty scopes + private endpoint",
6392
),
6493
pytest.param(
65-
"openid profile collections:createbutnotcreate",
66-
{r"^/*": [("POST", ["collections:create"])]},
94+
"openid profile collection:createbutnotcreate",
95+
{r"^/*": [("POST", ["collection:create"])]},
6796
"/collections",
6897
"POST",
6998
False,
7099
id="invalid scopes + private endpoint",
71100
),
72101
pytest.param(
73-
"openid profile collections:create somethingelse",
102+
"openid profile collection:create somethingelse",
74103
{r"^/*": [("POST", [])]},
75104
"/collections",
76105
"POST",
@@ -79,7 +108,7 @@ def test_default_public_false(source_api_server, path, method, token_builder):
79108
),
80109
pytest.param(
81110
"openid",
82-
{r"^/collections/.*/items$": [("POST", ["collections:create"])]},
111+
{r"^/collections/.*/items$": [("POST", ["collection:create"])]},
83112
"/collections",
84113
"GET",
85114
True,

0 commit comments

Comments
 (0)