Skip to content

Commit 00b6a8a

Browse files
committed
Initial pass at DI tooling
1 parent 2ea5ab9 commit 00b6a8a

File tree

12 files changed

+155
-60
lines changed

12 files changed

+155
-60
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ requires = ["hatchling>=1.12.0"]
4343
dev = [
4444
"jwcrypto>=1.5.6",
4545
"pre-commit>=3.5.0",
46+
"pytest-asyncio>=0.25.1",
4647
"pytest-cov>=5.0.0",
4748
"pytest>=8.3.3",
4849
]

src/stac_auth_proxy/app.py

Lines changed: 3 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -40,38 +40,14 @@ def create_app(settings: Optional[Settings] = None) -> FastAPI:
4040

4141
proxy_handler = ReverseProxyHandler(
4242
upstream=str(settings.upstream_url),
43-
collections_filter=(
44-
settings.collections_filter(auth_scheme.maybe_validated_user)
45-
if settings.collections_filter
46-
else None
47-
),
48-
items_filter=(
49-
settings.items_filter(auth_scheme.maybe_validated_user)
50-
if settings.items_filter
51-
else None
52-
),
43+
auth_dependency=auth_scheme.maybe_validated_user,
44+
collections_filter=settings.collections_filter,
45+
items_filter=settings.items_filter,
5346
)
5447
openapi_handler = build_openapi_spec_handler(
5548
proxy=proxy_handler,
5649
oidc_config_url=str(settings.oidc_discovery_url),
5750
)
58-
59-
# TODO: How can we inject the collections_filter into only the endpoints that need it?
60-
# for endpoint, methods in settings.collections_filter_endpoints.items():
61-
# app.add_api_route(
62-
# endpoint,
63-
# partial(
64-
# proxy_handler.stream, collections_filter=settings.collections_filter
65-
# ),
66-
# methods=methods,
67-
# dependencies=[Security(auth_scheme.maybe_validated_user)],
68-
# )
69-
70-
# skip = [
71-
# settings.openapi_spec_endpoint,
72-
# *settings.collections_filter_endpoints,
73-
# ]
74-
7551
# Endpoints that are explicitely marked private
7652
for path, methods in settings.private_endpoints.items():
7753
app.add_api_route(

src/stac_auth_proxy/config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,12 @@ class ClassInput(BaseModel):
1717
args: Sequence[str] = Field(default_factory=list)
1818
kwargs: dict[str, str] = Field(default_factory=dict)
1919

20-
def __call__(self, token_dependency):
20+
def __call__(self):
2121
"""Dynamically load a class and instantiate it with args & kwargs."""
2222
module_path, class_name = self.cls.rsplit(".", 1)
2323
module = importlib.import_module(module_path)
2424
cls = getattr(module, class_name)
25-
return cls(*self.args, **self.kwargs, token_dependency=token_dependency)
25+
return cls(*self.args, **self.kwargs)
2626

2727

2828
class Settings(BaseSettings):

src/stac_auth_proxy/filters/template.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,21 @@
11
"""Generate CQL2 filter expressions via Jinja2 templating."""
22

3-
from typing import Annotated, Any, Callable
3+
from typing import Annotated, Any
44

55
from cql2 import Expr
6-
from fastapi import Request, Security
6+
from fastapi import Request
77
from jinja2 import BaseLoader, Environment
88

99
from ..utils.requests import extract_variables
1010

1111

12-
def Template(template_str: str, token_dependency: Callable[..., Any]):
12+
def Template(template_str: str):
1313
"""Generate CQL2 filter expressions via Jinja2 templating."""
1414
env = Environment(loader=BaseLoader).from_string(template_str)
1515

1616
async def dependency(
1717
request: Request,
18-
auth_token: Annotated[dict[str, Any], Security(token_dependency)],
18+
auth_token: Annotated[dict[str, Any], ...],
1919
) -> Expr:
2020
"""Render a CQL2 filter expression with the request and auth token."""
2121
# TODO: How to handle the case where auth_token is null?

src/stac_auth_proxy/handlers/reverse_proxy.py

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from starlette.datastructures import MutableHeaders
1313
from starlette.responses import StreamingResponse
1414

15-
from ..utils import filters
15+
from ..utils import di, filters
1616

1717
logger = logging.getLogger(__name__)
1818

@@ -22,6 +22,8 @@ class ReverseProxyHandler:
2222
"""Reverse proxy functionality."""
2323

2424
upstream: str
25+
auth_dependency: Callable
26+
2527
client: httpx.AsyncClient = None
2628

2729
# Filters
@@ -34,21 +36,21 @@ def __post_init__(self):
3436
base_url=self.upstream,
3537
timeout=httpx.Timeout(timeout=15.0),
3638
)
39+
self.collections_filter = (
40+
self.collections_filter() if self.collections_filter else None
41+
)
42+
self.items_filter = self.items_filter() if self.items_filter else None
3743

38-
# Update annotations to support FastAPI's dependency injection
39-
for endpoint in [self.proxy_request, self.stream]:
40-
endpoint.__annotations__["collections_filter"] = Annotated[
44+
# Inject auth dependency into filters
45+
for endpoint in [self.collections_filter, self.items_filter]:
46+
if not endpoint:
47+
continue
48+
endpoint.__annotations__["auth_token"] = Annotated[
4149
Optional[Expr],
42-
Depends(self.collections_filter or (lambda: None)),
50+
Depends(self.auth_dependency),
4351
]
4452

45-
async def proxy_request(
46-
self,
47-
request: Request,
48-
*,
49-
collections_filter: Annotated[Optional[Expr], Depends(...)] = None,
50-
stream=False,
51-
) -> httpx.Response:
53+
async def proxy_request(self, request: Request, *, stream=False) -> httpx.Response:
5254
"""Proxy a request to the upstream STAC API."""
5355
headers = MutableHeaders(request.headers)
5456
headers.setdefault("X-Forwarded-For", request.client.host)
@@ -58,12 +60,23 @@ async def proxy_request(
5860
query = request.url.query
5961

6062
# Apply filters
61-
if filters.is_collection_endpoint(path) and collections_filter:
62-
if request.method == "GET" and path == "/collections":
63+
if filters.is_collection_endpoint(path) and self.collections_filter:
64+
collections_filter = await di.call_with_injected_dependencies(
65+
func=self.collections_filter,
66+
request=request,
67+
)
68+
if request.method == "GET":
6369
query = filters.insert_filter(qs=query, filter=collections_filter)
64-
elif filters.is_item_endpoint(path) and self.items_filter:
70+
else:
71+
# TODO: Augment body
72+
...
73+
74+
if filters.is_item_endpoint(path) and self.items_filter:
6575
if request.method == "GET":
6676
query = filters.insert_filter(qs=query, filter=self.items_filter)
77+
else:
78+
# TODO: Augment body
79+
...
6780

6881
# https://github.com/fastapi/fastapi/discussions/7382#discussioncomment-5136466
6982
rp_req = self.client.build_request(
@@ -87,15 +100,11 @@ async def proxy_request(
87100
rp_resp.headers["X-Upstream-Time"] = f"{proxy_time:.3f}"
88101
return rp_resp
89102

90-
async def stream(
91-
self,
92-
request: Request,
93-
collections_filter: Annotated[Optional[Expr], Depends(...)],
94-
) -> StreamingResponse:
103+
async def stream(self, request: Request) -> StreamingResponse:
95104
"""Transparently proxy a request to the upstream STAC API."""
96105
rp_resp = await self.proxy_request(
97106
request,
98-
collections_filter=collections_filter,
107+
# collections_filter=collections_filter,
99108
stream=True,
100109
)
101110
return StreamingResponse(
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Utils module for stac_auth_proxy."""

src/stac_auth_proxy/utils/di.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,50 @@
1+
"""Dependency injection utilities for FastAPI."""
2+
3+
import asyncio
4+
5+
from fastapi import Request
16
from fastapi.dependencies.models import Dependant
7+
from fastapi.dependencies.utils import (
8+
get_parameterless_sub_dependant,
9+
solve_dependencies,
10+
)
11+
from fastapi.params import Depends
212

313

414
def has_any_security_requirements(dependency: Dependant) -> bool:
515
"""
616
Recursively check if any dependency within the hierarchy has a non-empty
717
security_requirements list.
818
"""
9-
if dependency.security_requirements:
10-
return True
11-
return any(
19+
return dependency.security_requirements or any(
1220
has_any_security_requirements(sub_dep) for sub_dep in dependency.dependencies
1321
)
22+
23+
24+
async def call_with_injected_dependencies(func, request: Request):
25+
"""
26+
Manually solves and injects dependencies for `func` using FastAPI's internal
27+
dependency injection machinery.
28+
"""
29+
dependant = get_parameterless_sub_dependant(
30+
depends=Depends(dependency=func),
31+
path=request.url.path,
32+
)
33+
34+
solved = await solve_dependencies(
35+
request=request,
36+
dependant=dependant,
37+
# response=response,
38+
# body=request.body,
39+
body=None,
40+
async_exit_stack=None,
41+
embed_body_fields=False,
42+
)
43+
44+
if solved.errors:
45+
raise RuntimeError(f"Dependency resolution error: {solved.errors}")
46+
47+
results = func(**solved.values)
48+
if asyncio.iscoroutine(results):
49+
return await results
50+
return results

src/stac_auth_proxy/utils/requests.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1-
import re
1+
"""Utility functions for working with HTTP requests."""
22

3+
import re
34
from urllib.parse import urlparse
5+
46
from httpx import Headers
57

68

tests/test_di.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
"""Tests for the dependency injection utility."""
2+
3+
from typing import Annotated
4+
5+
import pytest
6+
from fastapi import Depends, Request
7+
8+
from stac_auth_proxy.utils.di import call_with_injected_dependencies
9+
10+
11+
async def get_db_connection():
12+
"""Mock asynchronous function to get a DB connection."""
13+
# pretend you open a DB connection or retrieve a session
14+
return "some_db_connection"
15+
16+
17+
def get_special_value():
18+
"""Mock synchronous function to get a special value."""
19+
return 42
20+
21+
22+
async def async_func_with_dependencies(
23+
db_conn: Annotated[str, Depends(get_db_connection)],
24+
special_value: Annotated[int, Depends(get_special_value)],
25+
):
26+
"""Mock asynchronous dependency."""
27+
return (db_conn, special_value)
28+
29+
30+
def sync_func_with_dependencies(
31+
db_conn: Annotated[str, Depends(get_db_connection)],
32+
special_value: Annotated[int, Depends(get_special_value)],
33+
):
34+
"""Mock synchronous dependency."""
35+
return (db_conn, special_value)
36+
37+
38+
@pytest.mark.parametrize(
39+
"func",
40+
[async_func_with_dependencies, sync_func_with_dependencies],
41+
)
42+
@pytest.mark.asyncio
43+
async def test_di(func):
44+
"""Test dependency injection."""
45+
request = Request(
46+
scope={
47+
"type": "http",
48+
"method": "GET",
49+
"path": "/test",
50+
"headers": [],
51+
"query_string": b"",
52+
}
53+
)
54+
55+
result = await call_with_injected_dependencies(func, request=request)
56+
assert result == ("some_db_connection", 42)

tests/test_filters_jinja2.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from urllib.parse import parse_qs
44

5-
import httpx
65
import pytest
76
from fastapi.testclient import TestClient
87
from utils import AppFactory

0 commit comments

Comments
 (0)