Skip to content

Commit 76a52f1

Browse files
committed
refactor: expect filter generators to return string or dict, not Expr
1 parent 1b8fa28 commit 76a52f1

File tree

2 files changed

+8
-10
lines changed

2 files changed

+8
-10
lines changed

src/stac_auth_proxy/filters/template.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from dataclasses import dataclass, field
44
from typing import Any
55

6-
from cql2 import Expr
76
from jinja2 import BaseLoader, Environment
87

98

@@ -18,10 +17,6 @@ def __post_init__(self):
1817
"""Initialize the Jinja2 environment."""
1918
self.env = Environment(loader=BaseLoader).from_string(self.template_str)
2019

21-
async def __call__(self, context: dict[str, Any]) -> Expr:
20+
async def __call__(self, context: dict[str, Any]) -> str:
2221
"""Render a CQL2 filter expression with the request and auth token."""
23-
# TODO: How to handle the case where auth_token is null?
24-
cql2_str = self.env.render(**context).strip()
25-
cql2_expr = Expr(cql2_str)
26-
cql2_expr.validate()
27-
return cql2_expr
22+
return self.env.render(**context).strip()

src/stac_auth_proxy/middleware/BuildCql2FilterMiddleware.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import json
44
import re
55
from dataclasses import dataclass
6-
from typing import Callable, Optional
6+
from typing import Any, Awaitable, Callable, Optional
77

88
from cql2 import Expr
99
from starlette.requests import Request
@@ -37,7 +37,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
3737

3838
async def set_filter(body: Optional[dict] = None) -> None:
3939
assert filter_builder is not None
40-
cql2_filter = await filter_builder(
40+
filter_expr = await filter_builder(
4141
{
4242
"req": {
4343
"path": request.url.path,
@@ -50,6 +50,7 @@ async def set_filter(body: Optional[dict] = None) -> None:
5050
**scope["state"],
5151
}
5252
)
53+
cql2_filter = Expr(filter_expr)
5354
cql2_filter.validate()
5455
setattr(request.state, self.state_key, cql2_filter)
5556

@@ -76,7 +77,9 @@ async def receive_build_filter() -> Message:
7677

7778
return await self.app(scope, receive_build_filter, send)
7879

79-
def _get_filter(self, path: str) -> Optional[Callable[..., Expr]]:
80+
def _get_filter(
81+
self, path: str
82+
) -> Optional[Callable[..., Awaitable[str | dict[str, Any]]]]:
8083
"""Get the CQL2 filter builder for the given path."""
8184
endpoint_filters = [
8285
(r"^/collections(/[^/]+)?$", self.collections_filter),

0 commit comments

Comments
 (0)