Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions .github/workflows/cicd.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,19 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v3
- uses: pre-commit/[email protected]

test:
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v3

- name: Install uv
uses: astral-sh/setup-uv@v4
with:
enable-cache: true

- name: Run tests
run: |
uv run pytest
7 changes: 7 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
"python.testing.pytestArgs": [
"tests"
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
}
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ known_first_party = ["stac_auth_proxy"]
profile = "black"

[tool.ruff]
ignore = ["E501", "D203", "D212"]
ignore = ["E501", "D205", "D212"]
select = ["D", "E", "F"]

[build-system]
Expand All @@ -39,4 +39,5 @@ requires = ["hatchling>=1.12.0"]
[dependency-groups]
dev = [
"pre-commit>=3.5.0",
"pytest>=8.3.3",
]
5 changes: 5 additions & 0 deletions src/stac_auth_proxy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,8 @@
It includes FastAPI routes for handling authentication, authorization, and interaction
with some internal STAC API.
"""

from .app import create_app
from .config import Settings

__all__ = ["create_app", "Settings"]
2 changes: 1 addition & 1 deletion src/stac_auth_proxy/__main__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Main module for the STAC Auth Proxy."""
"""Entry point for running the module without customized code."""

import uvicorn
from uvicorn.config import LOGGING_CONFIG
Expand Down
56 changes: 56 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
"""Pytest fixtures."""

import threading

import pytest
import uvicorn
from fastapi import FastAPI


@pytest.fixture(scope="session")
def source_api():
"""Create upstream API for testing purposes."""
app = FastAPI(docs_url="/api.html", openapi_url="/api")

for path, methods in {
"/": ["GET"],
"/conformance": ["GET"],
"/queryables": ["GET"],
"/search": ["GET", "POST"],
"/collections": ["GET", "POST"],
"/collections/{collection_id}": ["GET", "PUT", "DELETE"],
"/collections/{collection_id}/items": ["GET", "POST"],
"/collections/{collection_id}/items/{item_id}": [
"GET",
"PUT",
"DELETE",
],
"/collections/{collection_id}/bulk_items": ["POST"],
}.items():
for method in methods:
# NOTE: declare routes per method separately to avoid warning of "Duplicate Operation ID ... for function <lambda>"
app.add_api_route(
path,
lambda: {"id": f"Response from {method}@{path}"},
methods=[method],
)

return app


@pytest.fixture(scope="session")
def source_api_server(source_api):
"""Run the source API in a background thread."""
host, port = "127.0.0.1", 8000
server = uvicorn.Server(
uvicorn.Config(
source_api,
host=host,
port=port,
)
)
thread = threading.Thread(target=server.run)
thread.start()
yield f"http://{host}:{port}"
server.should_exit = True
thread.join()
105 changes: 105 additions & 0 deletions tests/test_defaults.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
"""Basic test cases for the proxy app."""

import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient

from stac_auth_proxy import Settings, create_app


@pytest.fixture(scope="module")
def test_app(source_api_server: str) -> FastAPI:
"""Fixture for the proxy app, pointing to the source API."""
return create_app(
Settings.model_validate(
{
"upstream_url": source_api_server,
"oidc_discovery_url": "https://samples.auth0.com/.well-known/openid-configuration",
"default_public": False,
},
)
)


@pytest.mark.parametrize(
"path,method,expected_status",
[
("/", "GET", 200),
("/conformance", "GET", 200),
("/queryables", "GET", 200),
("/search", "GET", 200),
("/search", "POST", 200),
("/collections", "GET", 200),
("/collections", "POST", 403),
("/collections/example-collection", "GET", 200),
("/collections/example-collection", "PUT", 403),
("/collections/example-collection", "DELETE", 403),
("/collections/example-collection/items", "GET", 200),
("/collections/example-collection/items", "POST", 403),
("/collections/example-collection/items/example-item", "GET", 200),
("/collections/example-collection/items/example-item", "PUT", 403),
("/collections/example-collection/items/example-item", "DELETE", 403),
("/collections/example-collection/bulk_items", "POST", 403),
("/api.html", "GET", 200),
("/api", "GET", 200),
],
)
def test_default_public_true(source_api_server, path, method, expected_status):
"""
When default_public=true and private_endpoints aren't set, all endpoints should be
public except for transaction endpoints.
"""
test_app = create_app(
Settings.model_validate(
{
"upstream_url": source_api_server,
"oidc_discovery_url": "https://samples.auth0.com/.well-known/openid-configuration",
"default_public": True,
},
)
)
client = TestClient(test_app)
response = client.request(method=method, url=path)
assert response.status_code == expected_status


@pytest.mark.parametrize(
"path,method,expected_status",
[
("/", "GET", 403),
("/conformance", "GET", 403),
("/queryables", "GET", 403),
("/search", "GET", 403),
("/search", "POST", 403),
("/collections", "GET", 403),
("/collections", "POST", 403),
("/collections/example-collection", "GET", 403),
("/collections/example-collection", "PUT", 403),
("/collections/example-collection", "DELETE", 403),
("/collections/example-collection/items", "GET", 403),
("/collections/example-collection/items", "POST", 403),
("/collections/example-collection/items/example-item", "GET", 403),
("/collections/example-collection/items/example-item", "PUT", 403),
("/collections/example-collection/items/example-item", "DELETE", 403),
("/collections/example-collection/bulk_items", "POST", 403),
("/api.html", "GET", 200),
("/api", "GET", 200),
],
)
def test_default_public_false(source_api_server, path, method, expected_status):
"""
When default_public=false and private_endpoints aren't set, all endpoints should be
public except for transaction endpoints.
"""
test_app = create_app(
Settings.model_validate(
{
"upstream_url": source_api_server,
"oidc_discovery_url": "https://samples.auth0.com/.well-known/openid-configuration",
"default_public": False,
},
)
)
client = TestClient(test_app)
response = client.request(method=method, url=path)
assert response.status_code == expected_status
44 changes: 44 additions & 0 deletions tests/test_openapi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""Tests for OpenAPI spec handling."""

from fastapi import FastAPI
from fastapi.testclient import TestClient

from stac_auth_proxy import Settings, create_app


def test_no_edit_openapi_spec(source_api_server):
"""When no OpenAPI spec endpoint is set, the proxied OpenAPI spec is unaltered."""
app = create_app(
Settings(
upstream_url=source_api_server,
oidc_discovery_url="https://samples.auth0.com/.well-known/openid-configuration",
openapi_spec_endpoint=None,
)
)
client = TestClient(app)
response = client.get("/api")
assert response.status_code == 200
openapi = response.json()
assert "info" in openapi
assert "openapi" in openapi
assert "paths" in openapi
assert "oidcAuth" not in openapi.get("components", {}).get("securitySchemes", {})


def test_oidc_in_openapi_spec(source_api: FastAPI, source_api_server: str):
"""When OpenAPI spec endpoint is set, the proxied OpenAPI spec is augmented with oidc details."""
app = create_app(
Settings(
upstream_url=source_api_server,
oidc_discovery_url="https://samples.auth0.com/.well-known/openid-configuration",
openapi_spec_endpoint=source_api.openapi_url,
)
)
client = TestClient(app)
response = client.get(source_api.openapi_url)
assert response.status_code == 200
openapi = response.json()
assert "info" in openapi
assert "openapi" in openapi
assert "paths" in openapi
assert "oidcAuth" in openapi.get("components", {}).get("securitySchemes", {})
Loading