Skip to content

Commit 9373dee

Browse files
authored
feat: Support CEL access policies (#13)
* Add first pass of using CEL for custom auth policies * More fancy CEL test * Simplify guard pattern * Support args in guards for simpler initialization Reorg tests * Rename variable * Expand docs * pre-commit format * Add docstring for pre-commit
1 parent 50b99e6 commit 9373dee

File tree

10 files changed

+399
-1
lines changed

10 files changed

+399
-1
lines changed

README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,12 @@
22

33
STAC Auth Proxy is a proxy API that mediates between the client and and some internally accessible STAC API in order to provide a flexible authentication mechanism.
44

5+
## Features
6+
7+
- 🔐 Selectively apply OIDC auth to some or all endpoints & methods
8+
- 📖 Augments [OpenAPI](https://swagger.io/specification/) with auth information, keeping auto-generated docs (e.g. [Swagger UI](https://swagger.io/tools/swagger-ui/)) accurate
9+
- 💂‍♀️ Custom policies enforce complex access controls, defined with [Common Expression Language (CEL)](https://cel.dev/)
10+
511
## Installation
612

713
Set up connection to upstream STAC API and the OpenID Connect provider by setting the following environment variables:

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ classifiers = [
88
dependencies = [
99
"authlib>=1.3.2",
1010
"brotli>=1.1.0",
11+
"cel-python>=0.1.5",
1112
"eoapi-auth-utils>=0.4.0",
1213
"fastapi>=0.115.5",
1314
"httpx>=0.28.0",

src/stac_auth_proxy/app.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
authentication, authorization, and proxying of requests to some internal STAC API.
66
"""
77

8+
import logging
89
from typing import Optional
910

1011
from eoapi.auth_utils import OpenIdConnectAuth
@@ -14,6 +15,8 @@
1415
from .handlers import OpenApiSpecHandler, ReverseProxyHandler
1516
from .middleware import AddProcessTimeHeaderMiddleware
1617

18+
logger = logging.getLogger(__name__)
19+
1720

1821
def create_app(settings: Optional[Settings] = None) -> FastAPI:
1922
"""FastAPI Application Factory."""
@@ -26,6 +29,10 @@ def create_app(settings: Optional[Settings] = None) -> FastAPI:
2629
openid_configuration_url=str(settings.oidc_discovery_url)
2730
).valid_token_dependency
2831

32+
if settings.guard:
33+
logger.info("Wrapping auth scheme")
34+
auth_scheme = settings.guard(auth_scheme)
35+
2936
proxy_handler = ReverseProxyHandler(upstream=str(settings.upstream_url))
3037
openapi_handler = OpenApiSpecHandler(
3138
proxy=proxy_handler, oidc_config_url=str(settings.oidc_discovery_url)

src/stac_auth_proxy/config.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,30 @@
11
"""Configuration for the STAC Auth Proxy."""
22

3-
from typing import Optional, TypeAlias
3+
import importlib
4+
from typing import Optional, Sequence, TypeAlias
45

6+
from pydantic import BaseModel
57
from pydantic.networks import HttpUrl
68
from pydantic_settings import BaseSettings, SettingsConfigDict
79

810
EndpointMethods: TypeAlias = dict[str, list[str]]
911

1012

13+
class ClassInput(BaseModel):
14+
"""Input model for dynamically loading a class or function."""
15+
16+
cls: str
17+
args: Optional[Sequence[str]] = []
18+
kwargs: Optional[dict[str, str]] = {}
19+
20+
def __call__(self, token_dependency):
21+
"""Dynamically load a class and instantiate it with kwargs."""
22+
module_path, class_name = self.cls.rsplit(".", 1)
23+
module = importlib.import_module(module_path)
24+
cls = getattr(module, class_name)
25+
return cls(*self.args, **self.kwargs, token_dependency=token_dependency)
26+
27+
1128
class Settings(BaseSettings):
1229
"""Configuration settings for the STAC Auth Proxy."""
1330

@@ -30,3 +47,5 @@ class Settings(BaseSettings):
3047
openapi_spec_endpoint: Optional[str] = None
3148

3249
model_config = SettingsConfigDict(env_prefix="STAC_AUTH_PROXY_")
50+
51+
guard: Optional[ClassInput] = None
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
"""Module to define the access policy guards for the application."""
2+
3+
from .cel import cel
4+
5+
__all__ = ["cel"]

src/stac_auth_proxy/guards/cel.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
"""Guard using CEL (Common Expression Language, https://cel.dev)."""
2+
3+
from typing import Any, Callable
4+
5+
import celpy
6+
from fastapi import Depends, HTTPException, Request
7+
8+
from ..utils import extract_variables
9+
10+
11+
def cel(expression: str, token_dependency: Callable[..., Any]):
12+
"""Cel check factory."""
13+
env = celpy.Environment()
14+
ast = env.compile(expression)
15+
program = env.program(ast)
16+
17+
async def check(
18+
request: Request,
19+
auth_token=Depends(token_dependency),
20+
):
21+
request_data = {
22+
"path": request.url.path,
23+
"method": request.method,
24+
"query_params": dict(request.query_params),
25+
"path_params": extract_variables(request.url.path),
26+
"headers": dict(request.headers),
27+
"body": (
28+
await request.json()
29+
if request.headers.get("content-type") == "application/json"
30+
else (await request.body()).decode()
31+
),
32+
}
33+
34+
result = program.evaluate(
35+
celpy.json_to_cel(
36+
{
37+
"req": request_data,
38+
"token": auth_token,
39+
}
40+
)
41+
)
42+
if not result:
43+
raise HTTPException(status_code=403, detail="Forbidden (failed CEL check)")
44+
45+
return check

src/stac_auth_proxy/utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
"""Utility functions."""
22

3+
import re
4+
from urllib.parse import urlparse
5+
36
from httpx import Headers
47

58

@@ -14,3 +17,15 @@ def safe_headers(headers: Headers) -> dict[str, str]:
1417
for key, value in headers.items()
1518
if key.lower() not in excluded_headers
1619
}
20+
21+
22+
def extract_variables(url: str) -> dict:
23+
"""
24+
Extract variables from a URL path. Being that we use a catch-all endpoint for the proxy,
25+
we can't rely on the path parameters that FastAPI provides.
26+
"""
27+
path = urlparse(url).path
28+
# This allows either /items or /bulk_items, with an optional item_id following.
29+
pattern = r"^/collections/(?P<collection_id>[^/]+)(?:/(?:items|bulk_items)(?:/(?P<item_id>[^/]+))?)?/?$"
30+
match = re.match(pattern, path)
31+
return {k: v for k, v in match.groupdict().items() if v} if match else {}

tests/test_guards_cel.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
"""Tests for CEL guard."""
2+
3+
import pytest
4+
from fastapi.testclient import TestClient
5+
from utils import AppFactory
6+
7+
app_factory = AppFactory(
8+
oidc_discovery_url="https://samples.auth0.com/.well-known/openid-configuration",
9+
default_public=False,
10+
)
11+
12+
13+
@pytest.mark.parametrize(
14+
"endpoint, expected_status_code",
15+
[
16+
("/", 403),
17+
("/?foo=xyz", 403),
18+
("/?bar=foo", 403),
19+
("/?foo=bar", 200),
20+
("/?foo=xyz&foo=bar", 200), # Only the last value is checked
21+
("/?foo=bar&foo=xyz", 403), # Only the last value is checked
22+
],
23+
)
24+
def test_guard_query_params(
25+
source_api_server,
26+
token_builder,
27+
endpoint,
28+
expected_status_code,
29+
):
30+
"""Test guard with query parameters."""
31+
app = app_factory(
32+
upstream_url=source_api_server,
33+
guard={
34+
"cls": "stac_auth_proxy.guards.cel",
35+
"args": ('has(req.query_params.foo) && req.query_params.foo == "bar"',),
36+
},
37+
)
38+
client = TestClient(app, headers={"Authorization": f"Bearer {token_builder({})}"})
39+
response = client.get(endpoint)
40+
assert response.status_code == expected_status_code
41+
42+
43+
@pytest.mark.parametrize(
44+
"token_payload, expected_status_code",
45+
[
46+
({"foo": "bar"}, 403),
47+
({"collections": []}, 403),
48+
({"collections": ["foo", "bar"]}, 403),
49+
({"collections": ["xyz"]}, 200),
50+
({"collections": ["foo", "xyz"]}, 200),
51+
],
52+
)
53+
def test_guard_auth_token(
54+
source_api_server,
55+
token_builder,
56+
token_payload,
57+
expected_status_code,
58+
):
59+
"""Test guard with auth token."""
60+
app = app_factory(
61+
upstream_url=source_api_server,
62+
guard={
63+
"cls": "stac_auth_proxy.guards.cel",
64+
"args": (
65+
"""
66+
has(req.path_params.collection_id) && has(token.collections) &&
67+
req.path_params.collection_id in (token.collections)
68+
""",
69+
),
70+
},
71+
)
72+
client = TestClient(
73+
app, headers={"Authorization": f"Bearer {token_builder(token_payload)}"}
74+
)
75+
response = client.get("/collections/xyz")
76+
assert response.status_code == expected_status_code

tests/test_utils.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
"""Tests for OpenAPI spec handling."""
2+
3+
import pytest
4+
5+
from stac_auth_proxy.utils import extract_variables
6+
7+
8+
@pytest.mark.parametrize(
9+
"url, expected",
10+
(
11+
("/collections/123", {"collection_id": "123"}),
12+
("/collections/123/items", {"collection_id": "123"}),
13+
("/collections/123/bulk_items", {"collection_id": "123"}),
14+
("/collections/123/items/456", {"collection_id": "123", "item_id": "456"}),
15+
("/collections/123/bulk_items/456", {"collection_id": "123", "item_id": "456"}),
16+
("/other/123", {}),
17+
),
18+
)
19+
def test_extract_variables(url, expected):
20+
"""Test extracting variables from a URL path."""
21+
assert extract_variables(url) == expected

0 commit comments

Comments
 (0)