Skip to content

Commit a23fd66

Browse files
authored
Refactor codebase to use middleware pattern (#20)
* Refactor API spec augmentation to middleware * Breakout middleware into separate files * Rename middleware * Reorg * chore: reorg imports * Mv auth to middleware, rework url patterns to regex * fix: pass empty chunks while waiting for entirety of body * Rm scope scheck in auth enforcement * typing cleanup * fix: extract user from request state * Rm auth class in favor of middleware * Rework template tooling as dataclass * cleanup: rm unused code * Cleanup * CQL2 Middleware: in progress * In progress * Update tests * Take care to separate GET and POST middleware behavior * Conditionally enable CQL2 filters + cleanup * Add formatting helpers * Test tweaks * Fix querystring parse test util * bugfix: place auth middleware after cql2 middleware to ensure correct order * Allow customizing state key * Minor update to auth token * pre-commit cleanup * Rm unused code
1 parent 1e3d823 commit a23fd66

23 files changed

+494
-433
lines changed

.vscode/launch.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
],
1717
"jinja": true,
1818
"cwd": "${workspaceFolder}/src",
19+
"justMyCode": false
1920
}
2021
]
2122
}

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,4 +50,4 @@ dev = [
5050

5151
[tool.pytest.ini_options]
5252
asyncio_default_fixture_loop_scope = "function"
53-
asyncio_mode = "strict"
53+
asyncio_mode = "auto"

src/stac_auth_proxy/app.py

Lines changed: 37 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,17 @@
88
import logging
99
from typing import Optional
1010

11-
from fastapi import FastAPI, Security
11+
from fastapi import FastAPI
1212

13-
from .auth import OpenIdConnectAuth
1413
from .config import Settings
15-
from .handlers import ReverseProxyHandler, build_openapi_spec_handler
16-
from .middleware import AddProcessTimeHeaderMiddleware
14+
from .handlers import ReverseProxyHandler
15+
from .middleware import (
16+
AddProcessTimeHeaderMiddleware,
17+
ApplyCql2FilterMiddleware,
18+
BuildCql2FilterMiddleware,
19+
EnforceAuthMiddleware,
20+
OpenApiMiddleware,
21+
)
1722

1823
logger = logging.getLogger(__name__)
1924

@@ -25,56 +30,47 @@ def create_app(settings: Optional[Settings] = None) -> FastAPI:
2530
app = FastAPI(
2631
openapi_url=None, # Disable OpenAPI schema endpoint, we want to serve upstream's schema
2732
)
33+
2834
app.add_middleware(AddProcessTimeHeaderMiddleware)
2935

36+
if settings.openapi_spec_endpoint:
37+
app.add_middleware(
38+
OpenApiMiddleware,
39+
openapi_spec_path=settings.openapi_spec_endpoint,
40+
oidc_config_url=str(settings.oidc_discovery_url),
41+
private_endpoints=settings.private_endpoints,
42+
default_public=settings.default_public,
43+
)
44+
45+
if settings.items_filter:
46+
app.add_middleware(ApplyCql2FilterMiddleware)
47+
app.add_middleware(
48+
BuildCql2FilterMiddleware,
49+
# collections_filter=settings.collections_filter,
50+
items_filter=settings.items_filter(),
51+
)
52+
53+
app.add_middleware(
54+
EnforceAuthMiddleware,
55+
public_endpoints=settings.public_endpoints,
56+
private_endpoints=settings.private_endpoints,
57+
default_public=settings.default_public,
58+
oidc_config_url=settings.oidc_discovery_url,
59+
)
60+
3061
if settings.debug:
3162
app.add_api_route(
3263
"/_debug",
3364
lambda: {"settings": settings},
3465
methods=["GET"],
3566
)
3667

37-
# Tooling
38-
auth_scheme = OpenIdConnectAuth(
39-
openid_configuration_url=settings.oidc_discovery_url
40-
)
41-
proxy_handler = ReverseProxyHandler(
42-
upstream=str(settings.upstream_url),
43-
auth_dependency=auth_scheme.maybe_validated_user,
44-
collections_filter=settings.collections_filter,
45-
items_filter=settings.items_filter,
46-
)
47-
openapi_handler = build_openapi_spec_handler(
48-
proxy=proxy_handler,
49-
oidc_config_url=str(settings.oidc_discovery_url),
50-
)
51-
52-
# Configure security dependency for explicitely specified endpoints
53-
for path_methods, dependencies in [
54-
(settings.private_endpoints, [Security(auth_scheme.validated_user)]),
55-
(settings.public_endpoints, []),
56-
]:
57-
for path, methods in path_methods.items():
58-
endpoint = (
59-
openapi_handler
60-
if path == settings.openapi_spec_endpoint
61-
else proxy_handler.stream
62-
)
63-
app.add_api_route(
64-
path,
65-
endpoint=endpoint,
66-
methods=methods,
67-
dependencies=dependencies,
68-
)
69-
70-
# Catchall for remainder of the endpoints
68+
# Catchall for any endpoint
69+
proxy_handler = ReverseProxyHandler(upstream=str(settings.upstream_url))
7170
app.add_api_route(
7271
"/{path:path}",
7372
proxy_handler.stream,
7473
methods=["GET", "POST", "PUT", "PATCH", "DELETE"],
75-
dependencies=(
76-
[] if settings.default_public else [Security(auth_scheme.validated_user)]
77-
),
7874
)
7975

8076
return app

src/stac_auth_proxy/config.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -37,26 +37,26 @@ class Settings(BaseSettings):
3737
default_public: bool = False
3838
private_endpoints: EndpointMethods = {
3939
# https://github.com/stac-api-extensions/collection-transaction/blob/v1.0.0-beta.1/README.md#methods
40-
"/collections": ["POST"],
41-
"/collections/{collection_id}": ["PUT", "PATCH", "DELETE"],
40+
r"^/collections$": ["POST"],
41+
r"^/collections/([^/]+)$": ["PUT", "PATCH", "DELETE"],
4242
# https://github.com/stac-api-extensions/transaction/blob/v1.0.0-rc.3/README.md#methods
43-
"/collections/{collection_id}/items": ["POST"],
44-
"/collections/{collection_id}/items/{item_id}": ["PUT", "PATCH", "DELETE"],
43+
r"^/collections/([^/]+)/items$": ["POST"],
44+
r"^/collections/([^/]+)/items/([^/]+)$": ["PUT", "PATCH", "DELETE"],
4545
# https://stac-utils.github.io/stac-fastapi/api/stac_fastapi/extensions/third_party/bulk_transactions/#bulktransactionextension
46-
"/collections/{collection_id}/bulk_items": ["POST"],
46+
r"^/collections/([^/]+)/bulk_items$": ["POST"],
4747
}
48-
public_endpoints: EndpointMethods = {"/api.html": ["GET"], "/api": ["GET"]}
48+
public_endpoints: EndpointMethods = {r"^/api.html$": ["GET"], r"^/api$": ["GET"]}
4949
openapi_spec_endpoint: Optional[str] = None
5050

51-
collections_filter: Optional[ClassInput] = None
52-
collections_filter_endpoints: Optional[EndpointMethods] = {
53-
"/collections": ["GET"],
54-
"/collections/{collection_id}": ["GET"],
55-
}
51+
# collections_filter: Optional[ClassInput] = None
52+
# collections_filter_endpoints: Optional[EndpointMethods] = {
53+
# r"^/collections$": ["GET"],
54+
# r"^/collections$/([^/]+)": ["GET"],
55+
# }
5656
items_filter: Optional[ClassInput] = None
5757
items_filter_endpoints: Optional[EndpointMethods] = {
58-
"/search": ["POST"],
59-
"/collections/{collection_id}/items": ["GET", "POST"],
58+
r"^/search$": ["POST"],
59+
r"^/collections/([^/]+)/items$": ["GET", "POST"],
6060
}
6161

6262
model_config = SettingsConfigDict(env_prefix="STAC_AUTH_PROXY_")
Lines changed: 13 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,27 @@
11
"""Generate CQL2 filter expressions via Jinja2 templating."""
22

3-
from typing import Annotated, Any
3+
from dataclasses import dataclass, field
4+
from typing import Any
45

56
from cql2 import Expr
6-
from fastapi import Request
77
from jinja2 import BaseLoader, Environment
88

9-
from ..utils.requests import extract_variables
109

11-
12-
def Template(template_str: str):
10+
@dataclass
11+
class Template:
1312
"""Generate CQL2 filter expressions via Jinja2 templating."""
14-
env = Environment(loader=BaseLoader).from_string(template_str)
1513

16-
async def dependency(
17-
request: Request,
18-
auth_token: Annotated[dict[str, Any], ...],
19-
) -> Expr:
14+
template_str: str
15+
env: Environment = field(init=False)
16+
17+
def __post_init__(self):
18+
"""Initialize the Jinja2 environment."""
19+
self.env = Environment(loader=BaseLoader).from_string(self.template_str)
20+
21+
async def __call__(self, context: dict[str, Any]) -> Expr:
2022
"""Render a CQL2 filter expression with the request and auth token."""
2123
# TODO: How to handle the case where auth_token is null?
22-
context = {
23-
"req": {
24-
"path": request.url.path,
25-
"method": request.method,
26-
"query_params": dict(request.query_params),
27-
"path_params": extract_variables(request.url.path),
28-
"headers": dict(request.headers),
29-
"body": (
30-
await request.json()
31-
if request.headers.get("content-type") == "application/json"
32-
else (await request.body()).decode()
33-
),
34-
},
35-
"token": auth_token,
36-
}
37-
cql2_str = env.render(**context).strip()
24+
cql2_str = self.env.render(**context).strip()
3825
cql2_expr = Expr(cql2_str)
3926
cql2_expr.validate()
4027
return cql2_expr
41-
42-
return dependency
Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
"""Handlers to process requests."""
22

3-
from .open_api_spec import build_openapi_spec_handler
43
from .reverse_proxy import ReverseProxyHandler
54

6-
__all__ = ["build_openapi_spec_handler", "ReverseProxyHandler"]
5+
__all__ = ["ReverseProxyHandler"]

src/stac_auth_proxy/handlers/open_api_spec.py

Lines changed: 0 additions & 68 deletions
This file was deleted.

0 commit comments

Comments
 (0)