Skip to content

Commit 32382da

Browse files
committed
More fancy CEL test
1 parent cff2a24 commit 32382da

File tree

2 files changed

+52
-20
lines changed

2 files changed

+52
-20
lines changed

src/stac_auth_proxy/guards/cel.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from dataclasses import dataclass
22
from typing import Any
3+
import re
4+
from urllib.parse import urlparse
35

46
from fastapi import Request, Depends, HTTPException
57
import celpy
@@ -25,6 +27,7 @@ async def check(
2527
"path": request.url.path,
2628
"method": request.method,
2729
"query_params": dict(request.query_params),
30+
"path_params": extract_variables(request.url.path),
2831
"headers": dict(request.headers),
2932
# Body may need to be read (await request.json()) or (await request.body()) if needed
3033
"body": (
@@ -34,6 +37,8 @@ async def check(
3437
),
3538
}
3639

40+
print(f"{request_data['path_params']=}")
41+
3742
result = self.program.evaluate(
3843
celpy.json_to_cel(
3944
{
@@ -48,3 +53,11 @@ async def check(
4853
)
4954

5055
self.check = check
56+
57+
58+
def extract_variables(url: str) -> dict:
59+
path = urlparse(url).path
60+
# This allows either /items or /bulk_items, with an optional item_id following.
61+
pattern = r"^/collections/(?P<collection_id>[^/]+)(?:/(?:items|bulk_items)(?:/(?P<item_id>[^/]+))?)?/?$"
62+
match = re.match(pattern, path)
63+
return {k: v for k, v in match.groupdict().items() if v} if match else {}

tests/test_guard.py

Lines changed: 39 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -10,31 +10,15 @@
1010
)
1111

1212

13-
import pytest
14-
from unittest.mock import patch, MagicMock
15-
16-
17-
# Fixture to patch OpenIdConnectAuth and mock valid_token_dependency
18-
@pytest.fixture
19-
def skip_auth():
20-
with patch("eoapi.auth_utils.OpenIdConnectAuth") as MockClass:
21-
# Create a mock instance
22-
mock_instance = MagicMock()
23-
# Set the return value of `valid_token_dependency`
24-
mock_instance.valid_token_dependency.return_value = "constant"
25-
# Assign the mock instance to the patched class's return value
26-
MockClass.return_value = mock_instance
27-
28-
# Yield the mock instance for use in tests
29-
yield mock_instance
30-
31-
3213
@pytest.mark.parametrize(
3314
"endpoint, expected_status_code",
3415
[
3516
("/", 403),
3617
("/?foo=xyz", 403),
18+
("/?bar=foo", 403),
3719
("/?foo=bar", 200),
20+
("/?foo=xyz&foo=bar", 200), # Only the last value is checked
21+
("/?foo=bar&foo=xyz", 403), # Only the last value is checked
3822
],
3923
)
4024
def test_guard_query_params(
@@ -43,7 +27,6 @@ def test_guard_query_params(
4327
endpoint,
4428
expected_status_code,
4529
):
46-
"""When no OpenAPI spec endpoint is set, the proxied OpenAPI spec is unaltered."""
4730
app = app_factory(
4831
upstream_url=source_api_server,
4932
guard={
@@ -56,3 +39,39 @@ def test_guard_query_params(
5639
client = TestClient(app, headers={"Authorization": f"Bearer {token_builder({})}"})
5740
response = client.get(endpoint)
5841
assert response.status_code == expected_status_code
42+
43+
44+
@pytest.mark.parametrize(
45+
"token, expected_status_code",
46+
[
47+
({"foo": "bar"}, 403),
48+
({"collections": []}, 403),
49+
({"collections": ["foo", "bar"]}, 403),
50+
({"collections": ["xyz"]}, 200),
51+
({"collections": ["foo", "xyz"]}, 200),
52+
],
53+
)
54+
def test_guard_auth_token(
55+
source_api_server,
56+
token_builder,
57+
token,
58+
expected_status_code,
59+
):
60+
app = app_factory(
61+
upstream_url=source_api_server,
62+
guard={
63+
"cls": "stac_auth_proxy.guards.cel.Cel",
64+
"kwargs": {
65+
"expression": """
66+
("collections" in token)
67+
&& ("collection_id" in req.path_params)
68+
&& (req.path_params.collection_id in token.collections)
69+
"""
70+
},
71+
},
72+
)
73+
client = TestClient(
74+
app, headers={"Authorization": f"Bearer {token_builder(token)}"}
75+
)
76+
response = client.get("/collections/xyz")
77+
assert response.status_code == expected_status_code

0 commit comments

Comments
 (0)