Skip to content

Commit e5eee66

Browse files
committed
Passing tests
1 parent 475f95a commit e5eee66

File tree

3 files changed

+51
-72
lines changed

3 files changed

+51
-72
lines changed

src/stac_auth_proxy/app.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,16 @@
66
"""
77

88
import logging
9-
from typing import Optional, Annotated
9+
from typing import Optional
1010

11-
from fastapi import FastAPI, Security, Request, Depends
12-
from cql2 import Expr
11+
from fastapi import FastAPI, Security
1312

1413
from .auth import OpenIdConnectAuth
1514
from .config import Settings
1615
from .handlers import OpenApiSpecHandler, ReverseProxyHandler
1716
from .middleware import AddProcessTimeHeaderMiddleware
18-
from .utils import apply_filter
17+
18+
# from .utils import apply_filter
1919

2020
logger = logging.getLogger(__name__)
2121

src/stac_auth_proxy/handlers/reverse_proxy.py

Lines changed: 22 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,16 @@
33
import logging
44
import time
55
from dataclasses import dataclass
6-
from typing import Optional, Annotated
6+
from typing import Annotated, Optional
77

8-
from cql2 import Expr
98
import httpx
10-
from fastapi import Request, Depends
9+
from cql2 import Expr
10+
from fastapi import Depends, Request
1111
from starlette.background import BackgroundTask
1212
from starlette.datastructures import MutableHeaders
1313
from starlette.responses import StreamingResponse
1414

15-
from ..utils import update_qs
15+
from .. import utils
1616

1717
logger = logging.getLogger(__name__)
1818

@@ -33,18 +33,18 @@ def __post_init__(self):
3333
timeout=httpx.Timeout(timeout=15.0),
3434
)
3535

36-
self.proxy_request.__annotations__["collections_filter"] = Annotated[
37-
Optional[Expr], Depends(self.collections_filter.dependency)
38-
]
39-
self.stream.__annotations__["collections_filter"] = Annotated[
40-
Optional[Expr], Depends(self.collections_filter.dependency)
41-
]
36+
# Update annotations to support FastAPI's dependency injection
37+
for endpoint in [self.proxy_request, self.stream]:
38+
endpoint.__annotations__["collections_filter"] = Annotated[
39+
Optional[Expr],
40+
Depends(getattr(self.collections_filter, "dependency", lambda: None)),
41+
]
4242

4343
async def proxy_request(
4444
self,
4545
request: Request,
4646
*,
47-
collections_filter: Annotated[Optional[Expr], Depends(...)],
47+
collections_filter: Annotated[Optional[Expr], Depends(...)] = None,
4848
stream=False,
4949
) -> httpx.Response:
5050
"""Proxy a request to the upstream STAC API."""
@@ -53,24 +53,22 @@ async def proxy_request(
5353
headers.setdefault("X-Forwarded-Host", request.url.hostname)
5454

5555
path = request.url.path
56-
query = request.url.query.encode("utf-8")
56+
query = request.url.query
5757

58-
# https://github.com/fastapi/fastapi/discussions/7382#discussioncomment-5136466
59-
# TODO: Examine filters
60-
if collections_filter:
58+
if utils.is_collection_endpoint(path) and collections_filter:
6159
if request.method == "GET" and path == "/collections":
62-
query += b"&" + update_qs(
63-
request.query_params, filter=collections_filter.to_text()
64-
)
65-
66-
url = httpx.URL(
67-
path=path,
68-
query=query,
69-
)
60+
query = utils.insert_filter(qs=query, filter=collections_filter)
61+
elif utils.is_item_endpoint(path) and self.items_filter:
62+
if request.method == "GET":
63+
query = utils.insert_filter(qs=query, filter=self.items_filter)
7064

65+
# https://github.com/fastapi/fastapi/discussions/7382#discussioncomment-5136466
7166
rp_req = self.client.build_request(
7267
request.method,
73-
url=url,
68+
url=httpx.URL(
69+
path=path,
70+
query=query.encode("utf-8"),
71+
),
7472
headers=headers,
7573
content=request.stream(),
7674
)

src/stac_auth_proxy/utils.py

Lines changed: 25 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11
"""Utility functions."""
22

33
import re
4-
from urllib.parse import urlparse
4+
from urllib.parse import parse_qs, urlencode, urlparse
55

66
from cql2 import Expr
7-
from fastapi import Request
87
from fastapi.dependencies.models import Dependant
9-
from starlette.datastructures import QueryParams
108
from httpx import Headers
119

1210

@@ -47,45 +45,28 @@ def has_any_security_requirements(dependency: Dependant) -> bool:
4745
)
4846

4947

50-
async def apply_filter(request: Request, filter: Expr) -> Request:
51-
"""Apply a CQL2 filter to a request."""
52-
req_filter = request.query_params.get("filter") or (
53-
(await request.json()).get("filter")
54-
if request.headers.get("content-length")
55-
else None
56-
)
48+
def insert_filter(qs: str, filter: Expr) -> str:
49+
"""Insert a filter expression into a query string. If a filter already exists, combine them."""
50+
qs_dict = parse_qs(qs)
5751

58-
new_filter = Expr(" AND ".join(e.to_text() for e in [req_filter, filter] if e))
59-
new_filter.validate()
60-
61-
if request.method == "GET":
62-
updated_scope = request.scope.copy()
63-
updated_scope["query_string"] = update_qs(
64-
request.query_params,
65-
filter=new_filter.to_text(),
66-
)
67-
return Request(
68-
scope=updated_scope,
69-
receive=request.receive,
70-
# send=request._send,
71-
)
72-
73-
# TODO: Support POST/PUT/PATCH
74-
# elif request.method == "POST":
75-
# request_body = await request.body()
76-
# query = request.url.query
77-
# query += "&" if query else "?"
78-
# query += f"filter={filter}"
79-
# request.url.query = query
80-
81-
return request
82-
83-
84-
def update_qs(query_params: QueryParams, **updates) -> bytes:
85-
query_dict = {
86-
**query_params,
87-
**updates,
88-
}
89-
return "&".join(f"{key}={value}" for key, value in query_dict.items()).encode(
90-
"utf-8"
91-
)
52+
filters = [Expr(f) for f in qs_dict.get("filter", [])]
53+
filters.append(filter)
54+
55+
combined_filter = Expr(" AND ".join(e.to_text() for e in filters))
56+
combined_filter.validate()
57+
58+
qs_dict["filter"] = [combined_filter.to_text()]
59+
60+
return urlencode(qs_dict, doseq=True)
61+
62+
63+
def is_collection_endpoint(path: str) -> bool:
64+
"""Check if the path is a collection endpoint."""
65+
# TODO: Expand this to cover all cases where a collection filter should be applied
66+
return path == "/collections"
67+
68+
69+
def is_item_endpoint(path: str) -> bool:
70+
"""Check if the path is an item endpoint."""
71+
# TODO: Expand this to cover all cases where an item filter should be applied
72+
return path == "/collection/{collection_id}/items"

0 commit comments

Comments
 (0)