Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@

STAC Auth Proxy is a proxy API that mediates between the client and and some internally accessible STAC API in order to provide a flexible authentication mechanism.

## Features

- 🔐 Selectively apply OIDC auth to some or all endpoints & methods
- 📖 Augments [OpenAPI](https://swagger.io/specification/) with auth information, keeping auto-generated docs (e.g. [Swagger UI](https://swagger.io/tools/swagger-ui/)) accurate
- 💂‍♀️ Custom policies enforce complex access controls, defined with [Common Expression Language (CEL)](https://cel.dev/)

## Installation

Set up connection to upstream STAC API and the OpenID Connect provider by setting the following environment variables:
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ classifiers = [
dependencies = [
"authlib>=1.3.2",
"brotli>=1.1.0",
"cel-python>=0.1.5",
"eoapi-auth-utils>=0.4.0",
"fastapi>=0.115.5",
"httpx>=0.28.0",
Expand Down
7 changes: 7 additions & 0 deletions src/stac_auth_proxy/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
authentication, authorization, and proxying of requests to some internal STAC API.
"""

import logging
from typing import Optional

from eoapi.auth_utils import OpenIdConnectAuth
Expand All @@ -14,6 +15,8 @@
from .handlers import OpenApiSpecHandler, ReverseProxyHandler
from .middleware import AddProcessTimeHeaderMiddleware

logger = logging.getLogger(__name__)


def create_app(settings: Optional[Settings] = None) -> FastAPI:
"""FastAPI Application Factory."""
Expand All @@ -26,6 +29,10 @@ def create_app(settings: Optional[Settings] = None) -> FastAPI:
openid_configuration_url=str(settings.oidc_discovery_url)
).valid_token_dependency

if settings.guard:
logger.info("Wrapping auth scheme")
auth_scheme = settings.guard(auth_scheme)

proxy_handler = ReverseProxyHandler(upstream=str(settings.upstream_url))
openapi_handler = OpenApiSpecHandler(
proxy=proxy_handler, oidc_config_url=str(settings.oidc_discovery_url)
Expand Down
21 changes: 20 additions & 1 deletion src/stac_auth_proxy/config.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,30 @@
"""Configuration for the STAC Auth Proxy."""

from typing import Optional, TypeAlias
import importlib
from typing import Optional, Sequence, TypeAlias

from pydantic import BaseModel
from pydantic.networks import HttpUrl
from pydantic_settings import BaseSettings, SettingsConfigDict

EndpointMethods: TypeAlias = dict[str, list[str]]


class ClassInput(BaseModel):
"""Input model for dynamically loading a class or function."""

cls: str
args: Optional[Sequence[str]] = []
kwargs: Optional[dict[str, str]] = {}

def __call__(self, token_dependency):
"""Dynamically load a class and instantiate it with kwargs."""
module_path, class_name = self.cls.rsplit(".", 1)
module = importlib.import_module(module_path)
cls = getattr(module, class_name)
return cls(*self.args, **self.kwargs, token_dependency=token_dependency)


class Settings(BaseSettings):
"""Configuration settings for the STAC Auth Proxy."""

Expand All @@ -30,3 +47,5 @@ class Settings(BaseSettings):
openapi_spec_endpoint: Optional[str] = None

model_config = SettingsConfigDict(env_prefix="STAC_AUTH_PROXY_")

guard: Optional[ClassInput] = None
5 changes: 5 additions & 0 deletions src/stac_auth_proxy/guards/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Module to define the access policy guards for the application."""

from .cel import cel

__all__ = ["cel"]
45 changes: 45 additions & 0 deletions src/stac_auth_proxy/guards/cel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""Guard using CEL (Common Expression Language, https://cel.dev)."""

from typing import Any, Callable

import celpy
from fastapi import Depends, HTTPException, Request

from ..utils import extract_variables


def cel(expression: str, token_dependency: Callable[..., Any]):
"""Cel check factory."""
env = celpy.Environment()
ast = env.compile(expression)
program = env.program(ast)

async def check(
request: Request,
auth_token=Depends(token_dependency),
):
request_data = {
"path": request.url.path,
"method": request.method,
"query_params": dict(request.query_params),
"path_params": extract_variables(request.url.path),
"headers": dict(request.headers),
"body": (
await request.json()
if request.headers.get("content-type") == "application/json"
else (await request.body()).decode()
),
}

result = program.evaluate(
celpy.json_to_cel(
{
"req": request_data,
"token": auth_token,
}
)
)
if not result:
raise HTTPException(status_code=403, detail="Forbidden (failed CEL check)")

return check
15 changes: 15 additions & 0 deletions src/stac_auth_proxy/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
"""Utility functions."""

import re
from urllib.parse import urlparse

from httpx import Headers


Expand All @@ -14,3 +17,15 @@ def safe_headers(headers: Headers) -> dict[str, str]:
for key, value in headers.items()
if key.lower() not in excluded_headers
}


def extract_variables(url: str) -> dict:
"""
Extract variables from a URL path. Being that we use a catch-all endpoint for the proxy,
we can't rely on the path parameters that FastAPI provides.
"""
path = urlparse(url).path
# This allows either /items or /bulk_items, with an optional item_id following.
pattern = r"^/collections/(?P<collection_id>[^/]+)(?:/(?:items|bulk_items)(?:/(?P<item_id>[^/]+))?)?/?$"
match = re.match(pattern, path)
return {k: v for k, v in match.groupdict().items() if v} if match else {}
76 changes: 76 additions & 0 deletions tests/test_guards_cel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""Tests for CEL guard."""

import pytest
from fastapi.testclient import TestClient
from utils import AppFactory

app_factory = AppFactory(
oidc_discovery_url="https://samples.auth0.com/.well-known/openid-configuration",
default_public=False,
)


@pytest.mark.parametrize(
"endpoint, expected_status_code",
[
("/", 403),
("/?foo=xyz", 403),
("/?bar=foo", 403),
("/?foo=bar", 200),
("/?foo=xyz&foo=bar", 200), # Only the last value is checked
("/?foo=bar&foo=xyz", 403), # Only the last value is checked
],
)
def test_guard_query_params(
source_api_server,
token_builder,
endpoint,
expected_status_code,
):
"""Test guard with query parameters."""
app = app_factory(
upstream_url=source_api_server,
guard={
"cls": "stac_auth_proxy.guards.cel",
"args": ('has(req.query_params.foo) && req.query_params.foo == "bar"',),
},
)
client = TestClient(app, headers={"Authorization": f"Bearer {token_builder({})}"})
response = client.get(endpoint)
assert response.status_code == expected_status_code


@pytest.mark.parametrize(
"token_payload, expected_status_code",
[
({"foo": "bar"}, 403),
({"collections": []}, 403),
({"collections": ["foo", "bar"]}, 403),
({"collections": ["xyz"]}, 200),
({"collections": ["foo", "xyz"]}, 200),
],
)
def test_guard_auth_token(
source_api_server,
token_builder,
token_payload,
expected_status_code,
):
"""Test guard with auth token."""
app = app_factory(
upstream_url=source_api_server,
guard={
"cls": "stac_auth_proxy.guards.cel",
"args": (
"""
has(req.path_params.collection_id) && has(token.collections) &&
req.path_params.collection_id in (token.collections)
""",
),
},
)
client = TestClient(
app, headers={"Authorization": f"Bearer {token_builder(token_payload)}"}
)
response = client.get("/collections/xyz")
assert response.status_code == expected_status_code
21 changes: 21 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""Tests for OpenAPI spec handling."""

import pytest

from stac_auth_proxy.utils import extract_variables


@pytest.mark.parametrize(
"url, expected",
(
("/collections/123", {"collection_id": "123"}),
("/collections/123/items", {"collection_id": "123"}),
("/collections/123/bulk_items", {"collection_id": "123"}),
("/collections/123/items/456", {"collection_id": "123", "item_id": "456"}),
("/collections/123/bulk_items/456", {"collection_id": "123", "item_id": "456"}),
("/other/123", {}),
),
)
def test_extract_variables(url, expected):
"""Test extracting variables from a URL path."""
assert extract_variables(url) == expected
Loading
Loading