Skip to content

Commit 4a6765a

Browse files
committed
Add first pass of using CEL for custom auth policies
1 parent 50b99e6 commit 4a6765a

File tree

7 files changed

+332
-0
lines changed

7 files changed

+332
-0
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ classifiers = [
88
dependencies = [
99
"authlib>=1.3.2",
1010
"brotli>=1.1.0",
11+
"cel-python>=0.1.5",
1112
"eoapi-auth-utils>=0.4.0",
1213
"fastapi>=0.115.5",
1314
"httpx>=0.28.0",

src/stac_auth_proxy/app.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
authentication, authorization, and proxying of requests to some internal STAC API.
66
"""
77

8+
import logging
89
from typing import Optional
910

1011
from eoapi.auth_utils import OpenIdConnectAuth
@@ -14,6 +15,8 @@
1415
from .handlers import OpenApiSpecHandler, ReverseProxyHandler
1516
from .middleware import AddProcessTimeHeaderMiddleware
1617

18+
logger = logging.getLogger(__name__)
19+
1720

1821
def create_app(settings: Optional[Settings] = None) -> FastAPI:
1922
"""FastAPI Application Factory."""
@@ -26,6 +29,11 @@ def create_app(settings: Optional[Settings] = None) -> FastAPI:
2629
openid_configuration_url=str(settings.oidc_discovery_url)
2730
).valid_token_dependency
2831

32+
if settings.guard:
33+
logger.info("Wrapping auth scheme")
34+
auth_scheme = settings.guard(auth_scheme).check
35+
print(f"{auth_scheme=}")
36+
2937
proxy_handler = ReverseProxyHandler(upstream=str(settings.upstream_url))
3038
openapi_handler = OpenApiSpecHandler(
3139
proxy=proxy_handler, oidc_config_url=str(settings.oidc_discovery_url)

src/stac_auth_proxy/config.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,27 @@
11
"""Configuration for the STAC Auth Proxy."""
22

3+
import importlib
34
from typing import Optional, TypeAlias
45

6+
from pydantic import BaseModel
57
from pydantic.networks import HttpUrl
68
from pydantic_settings import BaseSettings, SettingsConfigDict
79

810
EndpointMethods: TypeAlias = dict[str, list[str]]
911

1012

13+
class ClassInput(BaseModel):
14+
cls: str
15+
kwargs: Optional[dict[str, str]] = {}
16+
17+
def __call__(self, token_dependency):
18+
"""Dynamically load a class and instantiate it with kwargs."""
19+
module_path, class_name = self.cls.rsplit(".", 1)
20+
module = importlib.import_module(module_path)
21+
cls = getattr(module, class_name)
22+
return cls(**self.kwargs, token_dependency=token_dependency)
23+
24+
1125
class Settings(BaseSettings):
1226
"""Configuration settings for the STAC Auth Proxy."""
1327

@@ -30,3 +44,5 @@ class Settings(BaseSettings):
3044
openapi_spec_endpoint: Optional[str] = None
3145

3246
model_config = SettingsConfigDict(env_prefix="STAC_AUTH_PROXY_")
47+
48+
guard: Optional[ClassInput] = None

src/stac_auth_proxy/guards/__init__.py

Whitespace-only changes.

src/stac_auth_proxy/guards/cel.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
from dataclasses import dataclass
2+
from typing import Any
3+
4+
from fastapi import Request, Depends, HTTPException
5+
import celpy
6+
7+
8+
@dataclass
9+
class Cel:
10+
"""Custom middleware."""
11+
12+
expression: str
13+
token_dependency: Any
14+
15+
def __post_init__(self):
16+
env = celpy.Environment()
17+
ast = env.compile(self.expression)
18+
self.program = env.program(ast)
19+
20+
async def check(
21+
request: Request,
22+
auth_token=Depends(self.token_dependency),
23+
):
24+
request_data = {
25+
"path": request.url.path,
26+
"method": request.method,
27+
"query_params": dict(request.query_params), # Convert to a dict
28+
"headers": dict(request.headers), # Convert headers to a dict if needed
29+
# Body may need to be read (await request.json()) or (await request.body()) if needed
30+
"body": (
31+
await request.json()
32+
if request.headers.get("content-type") == "application/json"
33+
else (await request.body()).decode()
34+
),
35+
}
36+
37+
activation = {"req": request_data, "token": auth_token}
38+
print(f"{activation=}")
39+
result = self.program.evaluate(celpy.json_to_cel(activation))
40+
print(f"{result=}")
41+
if not result:
42+
raise HTTPException(
43+
status_code=403, detail="Forbidden (failed CEL check)"
44+
)
45+
46+
self.check = check

tests/test_guard.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
"""Tests for OpenAPI spec handling."""
2+
3+
import pytest
4+
from fastapi.testclient import TestClient
5+
from utils import AppFactory
6+
7+
app_factory = AppFactory(
8+
oidc_discovery_url="https://samples.auth0.com/.well-known/openid-configuration",
9+
default_public=False,
10+
)
11+
12+
13+
import pytest
14+
from unittest.mock import patch, MagicMock
15+
16+
17+
# Fixture to patch OpenIdConnectAuth and mock valid_token_dependency
18+
@pytest.fixture
19+
def skip_auth():
20+
with patch("eoapi.auth_utils.OpenIdConnectAuth") as MockClass:
21+
# Create a mock instance
22+
mock_instance = MagicMock()
23+
# Set the return value of `valid_token_dependency`
24+
mock_instance.valid_token_dependency.return_value = "constant"
25+
# Assign the mock instance to the patched class's return value
26+
MockClass.return_value = mock_instance
27+
28+
# Yield the mock instance for use in tests
29+
yield mock_instance
30+
31+
32+
@pytest.mark.parametrize(
33+
"endpoint, expected_status_code",
34+
[
35+
("/", 403),
36+
("/?foo=xyz", 403),
37+
("/?foo=bar", 200),
38+
],
39+
)
40+
def test_guard_query_params(
41+
source_api_server,
42+
token_builder,
43+
endpoint,
44+
expected_status_code,
45+
):
46+
"""When no OpenAPI spec endpoint is set, the proxied OpenAPI spec is unaltered."""
47+
app = app_factory(
48+
upstream_url=source_api_server,
49+
guard={
50+
"cls": "stac_auth_proxy.guards.cel.Cel",
51+
"kwargs": {
52+
"expression": '("foo" in req.query_params) && req.query_params.foo == "bar"'
53+
},
54+
},
55+
)
56+
client = TestClient(app, headers={"Authorization": f"Bearer {token_builder({})}"})
57+
response = client.get(endpoint)
58+
assert response.status_code == expected_status_code

0 commit comments

Comments
 (0)