Skip to content

Commit 2ea5ab9

Browse files
committed
Reorg utils
1 parent 67f6b3a commit 2ea5ab9

File tree

10 files changed

+106
-81
lines changed

10 files changed

+106
-81
lines changed

src/stac_auth_proxy/app.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,22 @@ def create_app(settings: Optional[Settings] = None) -> FastAPI:
5656
oidc_config_url=str(settings.oidc_discovery_url),
5757
)
5858

59+
# TODO: How can we inject the collections_filter into only the endpoints that need it?
60+
# for endpoint, methods in settings.collections_filter_endpoints.items():
61+
# app.add_api_route(
62+
# endpoint,
63+
# partial(
64+
# proxy_handler.stream, collections_filter=settings.collections_filter
65+
# ),
66+
# methods=methods,
67+
# dependencies=[Security(auth_scheme.maybe_validated_user)],
68+
# )
69+
70+
# skip = [
71+
# settings.openapi_spec_endpoint,
72+
# *settings.collections_filter_endpoints,
73+
# ]
74+
5975
# Endpoints that are explicitely marked private
6076
for path, methods in settings.private_endpoints.items():
6177
app.add_api_route(

src/stac_auth_proxy/config.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class ClassInput(BaseModel):
1818
kwargs: dict[str, str] = Field(default_factory=dict)
1919

2020
def __call__(self, token_dependency):
21-
"""Dynamically load a class and instantiate it with kwargs."""
21+
"""Dynamically load a class and instantiate it with args & kwargs."""
2222
module_path, class_name = self.cls.rsplit(".", 1)
2323
module = importlib.import_module(module_path)
2424
cls = getattr(module, class_name)
@@ -49,6 +49,10 @@ class Settings(BaseSettings):
4949
openapi_spec_endpoint: Optional[str] = None
5050

5151
collections_filter: Optional[ClassInput] = None
52+
collections_filter_endpoints: Optional[EndpointMethods] = {
53+
"/collections": ["GET"],
54+
"/collections/{collection_id}": ["GET"],
55+
}
5256
items_filter: Optional[ClassInput] = None
5357

5458
model_config = SettingsConfigDict(env_prefix="STAC_AUTH_PROXY_")

src/stac_auth_proxy/filters/template.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from fastapi import Request, Security
77
from jinja2 import BaseLoader, Environment
88

9-
from ..utils import extract_variables
9+
from ..utils.requests import extract_variables
1010

1111

1212
def Template(template_str: str, token_dependency: Callable[..., Any]):

src/stac_auth_proxy/handlers/open_api_spec.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
from fastapi import Request, Response
66
from fastapi.routing import APIRoute
77

8-
from ..utils import has_any_security_requirements, safe_headers
8+
from ..utils.di import has_any_security_requirements
9+
from ..utils.requests import safe_headers
910
from .reverse_proxy import ReverseProxyHandler
1011

1112
logger = logging.getLogger(__name__)

src/stac_auth_proxy/handlers/reverse_proxy.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from starlette.datastructures import MutableHeaders
1313
from starlette.responses import StreamingResponse
1414

15-
from .. import utils
15+
from ..utils import filters
1616

1717
logger = logging.getLogger(__name__)
1818

@@ -23,6 +23,8 @@ class ReverseProxyHandler:
2323

2424
upstream: str
2525
client: httpx.AsyncClient = None
26+
27+
# Filters
2628
collections_filter: Optional[Callable] = None
2729
items_filter: Optional[Callable] = None
2830

@@ -55,13 +57,13 @@ async def proxy_request(
5557
path = request.url.path
5658
query = request.url.query
5759

58-
# Appliy filters
59-
if utils.is_collection_endpoint(path) and collections_filter:
60+
# Apply filters
61+
if filters.is_collection_endpoint(path) and collections_filter:
6062
if request.method == "GET" and path == "/collections":
61-
query = utils.insert_filter(qs=query, filter=collections_filter)
62-
elif utils.is_item_endpoint(path) and self.items_filter:
63+
query = filters.insert_filter(qs=query, filter=collections_filter)
64+
elif filters.is_item_endpoint(path) and self.items_filter:
6365
if request.method == "GET":
64-
query = utils.insert_filter(qs=query, filter=self.items_filter)
66+
query = filters.insert_filter(qs=query, filter=self.items_filter)
6567

6668
# https://github.com/fastapi/fastapi/discussions/7382#discussioncomment-5136466
6769
rp_req = self.client.build_request(

src/stac_auth_proxy/utils.py

Lines changed: 0 additions & 72 deletions
This file was deleted.

src/stac_auth_proxy/utils/__init__.py

Whitespace-only changes.

src/stac_auth_proxy/utils/di.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from fastapi.dependencies.models import Dependant
2+
3+
4+
def has_any_security_requirements(dependency: Dependant) -> bool:
5+
"""
6+
Recursively check if any dependency within the hierarchy has a non-empty
7+
security_requirements list.
8+
"""
9+
if dependency.security_requirements:
10+
return True
11+
return any(
12+
has_any_security_requirements(sub_dep) for sub_dep in dependency.dependencies
13+
)
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
"""Utility functions."""
2+
3+
from urllib.parse import parse_qs, urlencode
4+
5+
from cql2 import Expr
6+
7+
8+
def insert_filter(qs: str, filter: Expr) -> str:
9+
"""Insert a filter expression into a query string. If a filter already exists, combine them."""
10+
qs_dict = parse_qs(qs)
11+
12+
filters = [Expr(f) for f in qs_dict.get("filter", [])]
13+
filters.append(filter)
14+
15+
combined_filter = Expr(" AND ".join(e.to_text() for e in filters))
16+
combined_filter.validate()
17+
18+
qs_dict["filter"] = [combined_filter.to_text()]
19+
20+
return urlencode(qs_dict, doseq=True)
21+
22+
23+
def is_collection_endpoint(path: str) -> bool:
24+
"""Check if the path is a collection endpoint."""
25+
# TODO: Expand this to cover all cases where a collection filter should be applied
26+
return path == "/collections"
27+
28+
29+
def is_item_endpoint(path: str) -> bool:
30+
"""Check if the path is an item endpoint."""
31+
# TODO: Expand this to cover all cases where an item filter should be applied
32+
return path == "/collection/{collection_id}/items"
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import re
2+
3+
from urllib.parse import urlparse
4+
from httpx import Headers
5+
6+
7+
def safe_headers(headers: Headers) -> dict[str, str]:
8+
"""Scrub headers that should not be proxied to the client."""
9+
excluded_headers = [
10+
"content-length",
11+
"content-encoding",
12+
]
13+
return {
14+
key: value
15+
for key, value in headers.items()
16+
if key.lower() not in excluded_headers
17+
}
18+
19+
20+
def extract_variables(url: str) -> dict:
21+
"""
22+
Extract variables from a URL path. Being that we use a catch-all endpoint for the proxy,
23+
we can't rely on the path parameters that FastAPI provides.
24+
"""
25+
path = urlparse(url).path
26+
# This allows either /items or /bulk_items, with an optional item_id following.
27+
pattern = r"^/collections/(?P<collection_id>[^/]+)(?:/(?:items|bulk_items)(?:/(?P<item_id>[^/]+))?)?/?$"
28+
match = re.match(pattern, path)
29+
return {k: v for k, v in match.groupdict().items() if v} if match else {}

0 commit comments

Comments
 (0)