Skip to content

Commit 475f95a

Browse files
committed
Crude working collections filter
1 parent 29b4806 commit 475f95a

File tree

5 files changed

+168
-38
lines changed

5 files changed

+168
-38
lines changed

src/stac_auth_proxy/app.py

Lines changed: 46 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,16 @@
66
"""
77

88
import logging
9-
from typing import Optional
9+
from typing import Optional, Annotated
1010

11-
from fastapi import Depends, FastAPI
11+
from fastapi import FastAPI, Security, Request, Depends
12+
from cql2 import Expr
1213

1314
from .auth import OpenIdConnectAuth
1415
from .config import Settings
1516
from .handlers import OpenApiSpecHandler, ReverseProxyHandler
1617
from .middleware import AddProcessTimeHeaderMiddleware
18+
from .utils import apply_filter
1719

1820
logger = logging.getLogger(__name__)
1921

@@ -28,8 +30,8 @@ def create_app(settings: Optional[Settings] = None) -> FastAPI:
2830
app.add_middleware(AddProcessTimeHeaderMiddleware)
2931

3032
auth_scheme = OpenIdConnectAuth(
31-
openid_configuration_url=str(settings.oidc_discovery_url)
32-
).valid_token_dependency
33+
openid_configuration_url=settings.oidc_discovery_url
34+
)
3335

3436
if settings.debug:
3537
app.add_api_route(
@@ -38,12 +40,40 @@ def create_app(settings: Optional[Settings] = None) -> FastAPI:
3840
methods=["GET"],
3941
)
4042

41-
proxy_handler = ReverseProxyHandler(upstream=str(settings.upstream_url))
43+
collections_filter = (
44+
settings.collections_filter(auth_scheme.maybe_validated_user)
45+
if settings.collections_filter
46+
else None
47+
)
48+
items_filter = (
49+
settings.items_filter(auth_scheme.maybe_validated_user)
50+
if settings.items_filter
51+
else None
52+
)
53+
proxy_handler = ReverseProxyHandler(
54+
upstream=str(settings.upstream_url),
55+
collections_filter=collections_filter,
56+
items_filter=items_filter,
57+
)
4258
openapi_handler = OpenApiSpecHandler(
4359
proxy=proxy_handler,
4460
oidc_config_url=str(settings.oidc_discovery_url),
4561
)
4662

63+
# @app.get("/collections")
64+
# async def collections(
65+
# request: Request,
66+
# filter: Annotated[Optional[Expr], Depends(collections_filter.dependency)],
67+
# ):
68+
# # if filter:
69+
# # print(f"{request.receive=}")
70+
# # request = await apply_filter(
71+
# # request,
72+
# # filter,
73+
# # )
74+
# # print(f"{request.receive=}")
75+
# return await proxy_handler.stream(request=request)
76+
4777
# Endpoints that are explicitely marked private
4878
for path, methods in settings.private_endpoints.items():
4979
app.add_api_route(
@@ -54,7 +84,7 @@ def create_app(settings: Optional[Settings] = None) -> FastAPI:
5484
else openapi_handler.dispatch
5585
),
5686
methods=methods,
57-
dependencies=[Depends(auth_scheme)],
87+
dependencies=[Security(auth_scheme.validated_user)],
5888
)
5989

6090
# Endpoints that are explicitely marked as public
@@ -67,14 +97,23 @@ def create_app(settings: Optional[Settings] = None) -> FastAPI:
6797
else openapi_handler.dispatch
6898
),
6999
methods=methods,
100+
dependencies=[Security(auth_scheme.maybe_validated_user)],
70101
)
71102

72103
# Catchall for remainder of the endpoints
73104
app.add_api_route(
74105
"/{path:path}",
75106
proxy_handler.stream,
76107
methods=["GET", "POST", "PUT", "PATCH", "DELETE"],
77-
dependencies=([] if settings.default_public else [Depends(auth_scheme)]),
108+
dependencies=(
109+
[
110+
Security(
111+
auth_scheme.maybe_validated_user
112+
if settings.default_public
113+
else auth_scheme.validated_user
114+
)
115+
]
116+
),
78117
)
79118

80119
return app

src/stac_auth_proxy/auth.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class OpenIdConnectAuth:
3131
def __post_init__(self):
3232
"""Initialize the OIDC authentication class."""
3333
logger.debug("Requesting OIDC config")
34-
origin_url = (
34+
origin_url = str(
3535
self.openid_configuration_internal_url or self.openid_configuration_url
3636
)
3737
with urllib.request.urlopen(origin_url) as response:

src/stac_auth_proxy/filters/template.py

Lines changed: 30 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -19,31 +19,37 @@ class Template:
1919

2020
# Generated attributes
2121
env: Environment = field(init=False)
22+
dependency: Callable[[Request, Security], Expr] = field(init=False)
2223

2324
def __post_init__(self):
2425
"""Initialize the Jinja2 environment."""
2526
self.env = Environment(loader=BaseLoader).from_string(self.template_str)
26-
self.render.__annotations__["auth_token"] = Security(self.token_dependency)
27-
28-
async def cql2(self, request: Request, auth_token=Security(...)) -> Expr:
29-
"""Render a CQL2 filter expression with the request and auth token."""
30-
# TODO: How to handle the case where auth_token is null?
31-
context = {
32-
"req": {
33-
"path": request.url.path,
34-
"method": request.method,
35-
"query_params": dict(request.query_params),
36-
"path_params": extract_variables(request.url.path),
37-
"headers": dict(request.headers),
38-
"body": (
39-
await request.json()
40-
if request.headers.get("content-type") == "application/json"
41-
else (await request.body()).decode()
42-
),
43-
},
44-
"token": auth_token,
45-
}
46-
cql2_str = self.env.render(**context)
47-
cql2_expr = Expr(cql2_str)
48-
cql2_expr.validate()
49-
return cql2_expr
27+
self.dependency = self.build()
28+
29+
def build(self):
30+
async def dependency(
31+
request: Request, auth_token=Security(self.token_dependency)
32+
) -> Expr:
33+
"""Render a CQL2 filter expression with the request and auth token."""
34+
# TODO: How to handle the case where auth_token is null?
35+
context = {
36+
"req": {
37+
"path": request.url.path,
38+
"method": request.method,
39+
"query_params": dict(request.query_params),
40+
"path_params": extract_variables(request.url.path),
41+
"headers": dict(request.headers),
42+
"body": (
43+
await request.json()
44+
if request.headers.get("content-type") == "application/json"
45+
else (await request.body()).decode()
46+
),
47+
},
48+
"token": auth_token,
49+
}
50+
cql2_str = self.env.render(**context)
51+
cql2_expr = Expr(cql2_str)
52+
cql2_expr.validate()
53+
return cql2_expr
54+
55+
return dependency

src/stac_auth_proxy/handlers/reverse_proxy.py

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,17 @@
33
import logging
44
import time
55
from dataclasses import dataclass
6+
from typing import Optional, Annotated
67

8+
from cql2 import Expr
79
import httpx
8-
from fastapi import Request
10+
from fastapi import Request, Depends
911
from starlette.background import BackgroundTask
1012
from starlette.datastructures import MutableHeaders
1113
from starlette.responses import StreamingResponse
1214

15+
from ..utils import update_qs
16+
1317
logger = logging.getLogger(__name__)
1418

1519

@@ -19,6 +23,8 @@ class ReverseProxyHandler:
1923

2024
upstream: str
2125
client: httpx.AsyncClient = None
26+
collections_filter: Optional[callable] = None
27+
items_filter: Optional[callable] = None
2228

2329
def __post_init__(self):
2430
"""Initialize the HTTP client."""
@@ -27,17 +33,41 @@ def __post_init__(self):
2733
timeout=httpx.Timeout(timeout=15.0),
2834
)
2935

30-
async def proxy_request(self, request: Request, *, stream=False) -> httpx.Response:
36+
self.proxy_request.__annotations__["collections_filter"] = Annotated[
37+
Optional[Expr], Depends(self.collections_filter.dependency)
38+
]
39+
self.stream.__annotations__["collections_filter"] = Annotated[
40+
Optional[Expr], Depends(self.collections_filter.dependency)
41+
]
42+
43+
async def proxy_request(
44+
self,
45+
request: Request,
46+
*,
47+
collections_filter: Annotated[Optional[Expr], Depends(...)],
48+
stream=False,
49+
) -> httpx.Response:
3150
"""Proxy a request to the upstream STAC API."""
3251
headers = MutableHeaders(request.headers)
3352
headers.setdefault("X-Forwarded-For", request.client.host)
3453
headers.setdefault("X-Forwarded-Host", request.url.hostname)
3554

55+
path = request.url.path
56+
query = request.url.query.encode("utf-8")
57+
3658
# https://github.com/fastapi/fastapi/discussions/7382#discussioncomment-5136466
59+
# TODO: Examine filters
60+
if collections_filter:
61+
if request.method == "GET" and path == "/collections":
62+
query += b"&" + update_qs(
63+
request.query_params, filter=collections_filter.to_text()
64+
)
65+
3766
url = httpx.URL(
38-
path=request.url.path,
39-
query=request.url.query.encode("utf-8"),
67+
path=path,
68+
query=query,
4069
)
70+
4171
rp_req = self.client.build_request(
4272
request.method,
4373
url=url,
@@ -56,9 +86,17 @@ async def proxy_request(self, request: Request, *, stream=False) -> httpx.Respon
5686
rp_resp.headers["X-Upstream-Time"] = f"{proxy_time:.3f}"
5787
return rp_resp
5888

59-
async def stream(self, request: Request) -> StreamingResponse:
89+
async def stream(
90+
self,
91+
request: Request,
92+
collections_filter: Annotated[Optional[Expr], Depends(...)],
93+
) -> StreamingResponse:
6094
"""Transparently proxy a request to the upstream STAC API."""
61-
rp_resp = await self.proxy_request(request, stream=True)
95+
rp_resp = await self.proxy_request(
96+
request,
97+
collections_filter=collections_filter,
98+
stream=True,
99+
)
62100
return StreamingResponse(
63101
rp_resp.aiter_raw(),
64102
status_code=rp_resp.status_code,

src/stac_auth_proxy/utils.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@
33
import re
44
from urllib.parse import urlparse
55

6+
from cql2 import Expr
7+
from fastapi import Request
68
from fastapi.dependencies.models import Dependant
9+
from starlette.datastructures import QueryParams
710
from httpx import Headers
811

912

@@ -42,3 +45,47 @@ def has_any_security_requirements(dependency: Dependant) -> bool:
4245
return any(
4346
has_any_security_requirements(sub_dep) for sub_dep in dependency.dependencies
4447
)
48+
49+
50+
async def apply_filter(request: Request, filter: Expr) -> Request:
51+
"""Apply a CQL2 filter to a request."""
52+
req_filter = request.query_params.get("filter") or (
53+
(await request.json()).get("filter")
54+
if request.headers.get("content-length")
55+
else None
56+
)
57+
58+
new_filter = Expr(" AND ".join(e.to_text() for e in [req_filter, filter] if e))
59+
new_filter.validate()
60+
61+
if request.method == "GET":
62+
updated_scope = request.scope.copy()
63+
updated_scope["query_string"] = update_qs(
64+
request.query_params,
65+
filter=new_filter.to_text(),
66+
)
67+
return Request(
68+
scope=updated_scope,
69+
receive=request.receive,
70+
# send=request._send,
71+
)
72+
73+
# TODO: Support POST/PUT/PATCH
74+
# elif request.method == "POST":
75+
# request_body = await request.body()
76+
# query = request.url.query
77+
# query += "&" if query else "?"
78+
# query += f"filter={filter}"
79+
# request.url.query = query
80+
81+
return request
82+
83+
84+
def update_qs(query_params: QueryParams, **updates) -> bytes:
85+
query_dict = {
86+
**query_params,
87+
**updates,
88+
}
89+
return "&".join(f"{key}={value}" for key, value in query_dict.items()).encode(
90+
"utf-8"
91+
)

0 commit comments

Comments
 (0)