Skip to content

Commit 7131140

Browse files
committed
feat: add row filters engine
1 parent d38ddc3 commit 7131140

File tree

20 files changed

+852
-313
lines changed

20 files changed

+852
-313
lines changed

policies/celine/dataset.rego

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,13 @@ allow if {
193193
input.action.name in ["query", "read"]
194194
}
195195

196+
# Grant query/read access to authenticated internal users who belong to the
# "viewers" group (parallel to the manager-group rule above).
allow if {
    is_internal
    is_user
    "viewers" in input.subject.groups
    input.action.name in ["query", "read"]
}
202+
196203
reason := "internal dataset - manager group granted" if {
197204
allow
198205
is_internal

src/celine/dataset/api/dataset_query/executor.py

Lines changed: 87 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,11 @@
33

44
import json
55
import logging
6+
7+
import httpx
8+
import sqlglot
69
from typing import Optional, Dict, Sequence, List
7-
from fastapi import HTTPException
10+
from fastapi import HTTPException, Request
811
from sqlalchemy import RowMapping, Table, text, select, func
912
from sqlalchemy.ext.asyncio import AsyncSession
1013
from sqlalchemy.exc import DBAPIError
@@ -19,11 +22,12 @@
1922
)
2023
from celine.dataset.security.models import AuthenticatedUser
2124
from celine.dataset.api.dataset_query.parser import parse_sql_query
22-
from celine.dataset.api.dataset_query.user_filter import (
23-
get_user_filter_column,
24-
inject_user_filter,
25-
is_admin_user,
25+
from celine.dataset.api.dataset_query.row_filters import (
26+
apply_row_filter_plans,
27+
get_row_filter_registry,
28+
get_row_filter_specs,
2629
)
30+
from celine.dataset.api.dataset_query.row_filters.utils import is_admin_user
2731

2832
logger = logging.getLogger(__name__)
2933

@@ -103,20 +107,15 @@ async def execute_query(
103107
- SQL validated (SELECT-only, table allowlist)
104108
- LIMIT/OFFSET enforced server-side
105109
- hard row cap applied
106-
- user filtering applied (if userFilterColumn defined)
110+
- row-level filters applied (pluggable governance handlers)
107111
"""
108-
# ------------------------------------------------------------------
109-
# Validate SQL
110-
# ------------------------------------------------------------------
111-
112112
if raw_sql is None or raw_sql.strip() == "":
113113
raise HTTPException(400, "sql query not provided")
114114

115115
logger.debug(f"Parsing raw SQL: {raw_sql}")
116116
try:
117117
parsed = parse_sql_query(raw_sql)
118-
except HTTPException as exc:
119-
logger.error(f"SQL validation failed: {exc}")
118+
except HTTPException:
120119
raise
121120
except Exception as exc:
122121
logger.exception("SQL validation failed")
@@ -126,12 +125,16 @@ async def execute_query(
126125
raise HTTPException(400, "Query references no datasets")
127126

128127
datasets = await resolve_datasets_for_tables(db=db, table_names=parsed.tables)
128+
129129
tables_map: dict[str, str] = {}
130-
user_filters: List[dict] = [] # Collect filters to apply
130+
row_filter_plans = []
131+
132+
registry = get_row_filter_registry()
131133

132134
for ref_table, ds in datasets.items():
133135
if not ds.expose:
134136
raise HTTPException(403, "Dataset not available")
137+
135138
await enforce_dataset_access(entry=ds, user=user)
136139

137140
if ds.backend_config is None:
@@ -148,45 +151,80 @@ async def execute_query(
148151
logger.debug(f"Mapped SQL table {ref_table} -> {phy_table_name}")
149152
tables_map[ref_table] = phy_table_name
150153

151-
# ------------------------------------------------------------------
152-
# Check for user filtering requirement
153-
# ------------------------------------------------------------------
154-
filter_column = get_user_filter_column(ds)
155-
if filter_column:
156-
if user is None:
154+
specs = get_row_filter_specs(ds)
155+
if not specs:
156+
continue
157+
158+
if user is None:
159+
raise HTTPException(
160+
401,
161+
f"Dataset {ref_table} requires authentication for row filtering",
162+
)
163+
164+
if is_admin_user(user):
165+
continue
166+
167+
for spec in specs:
168+
handler_name = spec.get("handler")
169+
args = spec.get("args") or {}
170+
if not isinstance(handler_name, str) or not handler_name:
157171
raise HTTPException(
158-
401,
159-
f"Dataset {ref_table} requires authentication for user filtering",
172+
500,
173+
f"Invalid row filter spec for dataset {ref_table}: missing handler",
174+
)
175+
if not isinstance(args, dict):
176+
raise HTTPException(
177+
500,
178+
f"Invalid row filter spec for dataset {ref_table}: args must be object",
160179
)
161180

162-
# Admins bypass user filtering
163-
if not is_admin_user(user):
164-
user_filters.append(
165-
{
166-
"table": phy_table_name,
167-
"column": filter_column,
168-
"user_sub": user.sub,
169-
}
181+
try:
182+
plan = await registry.resolve_with_cache(
183+
handler_name=handler_name,
184+
table=phy_table_name,
185+
user=user,
186+
args=args,
187+
request_context={},
188+
)
189+
except KeyError:
190+
logger.error(
191+
f"Unknown row filter handler '{handler_name}' for dataset {ref_table}"
192+
)
193+
raise HTTPException(
194+
500,
195+
f"Unknown row filter handler '{handler_name}' for dataset {ref_table}",
170196
)
171-
logger.debug(
172-
f"User filter required for {ref_table}: "
173-
f"{filter_column} = {user.sub}"
197+
except httpx.HTTPError:
198+
logger.error(f"Row filter resolution failed for dataset {ref_table}")
199+
raise HTTPException(
200+
403,
201+
f"Row filter resolution failed for dataset {ref_table}",
202+
)
203+
except Exception as e:
204+
logger.error(f"Row filter handler failed: {e}")
205+
raise HTTPException(
206+
500,
207+
f"Row filter handler '{handler_name}' failed for dataset {ref_table}",
174208
)
175209

176-
# Replace tables ID with physical tables
177-
complete_sql = parsed.to_sql(tables_map=tables_map)
178-
logger.debug(f"Complete SQL (before user filter): {complete_sql}")
210+
row_filter_plans.append(plan)
179211

180-
# ------------------------------------------------------------------
181-
# Inject user filters
182-
# ------------------------------------------------------------------
183-
if user_filters:
184-
complete_sql = inject_user_filter(complete_sql, user_filters)
185-
logger.debug(f"Complete SQL (after user filter): {complete_sql}")
212+
# Logical -> physical substitution
213+
complete_sql = parsed.to_sql(tables_map=tables_map)
214+
logger.debug(f"Complete SQL (after table mapping): {complete_sql}")
215+
216+
# Apply row-level filters
217+
if row_filter_plans:
218+
try:
219+
ast = sqlglot.parse_one(complete_sql)
220+
ast = apply_row_filter_plans(ast, row_filter_plans)
221+
complete_sql = ast.sql()
222+
except Exception:
223+
logger.exception("Failed to apply row filters")
224+
raise HTTPException(500, "Failed to apply row filters") from None
225+
logger.debug(f"Complete SQL (after row filters): {complete_sql}")
186226

187-
# ------------------------------------------------------------------
188227
# Pagination & caps
189-
# ------------------------------------------------------------------
190228
limit = _clamp_limit(limit)
191229
offset = max(offset, 0)
192230

@@ -204,40 +242,29 @@ async def execute_query(
204242
) AS q
205243
"""
206244

207-
# ------------------------------------------------------------------
208245
# Execute count
209-
# ------------------------------------------------------------------
210246
try:
211-
total = await execute_scalar_with_timeout(
212-
db,
213-
count_sql,
214-
)
215-
except HTTPException as e:
216-
logger.error(f"Query count failed: {e}")
247+
total = await execute_scalar_with_timeout(db, count_sql)
248+
except HTTPException:
217249
raise
218-
except Exception as exc: # safety net
250+
except Exception:
219251
logger.exception("Count query failed")
220252
raise HTTPException(500, "Query failed") from None
221253

222-
# ------------------------------------------------------------------
223254
# Execute data query
224-
# ------------------------------------------------------------------
225255
try:
226256
rows = await execute_rows_with_timeout(
227257
db,
228258
paginated_sql,
229259
{"limit": limit, "offset": offset},
230260
)
231-
except HTTPException as e:
232-
logger.error(f"Query execution failed: {e}")
261+
except HTTPException:
233262
raise
234-
except Exception as exc: # safety net
235-
logger.error("Query execution failed: {e}")
263+
except Exception:
264+
logger.exception("Query execution failed")
236265
raise HTTPException(500, "Query execution failed") from None
237266

238-
# ------------------------------------------------------------------
239267
# Post-process rows (geometry → GeoJSON)
240-
# ------------------------------------------------------------------
241268
items = []
242269
for r in rows:
243270
row = dict(r)
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from __future__ import annotations
2+
3+
from .registry import RowFilterRegistry, get_row_filter_registry
4+
from .specs import get_row_filter_specs
5+
from .apply import apply_row_filter_plans
6+
from .models import RowFilterPlan
7+
8+
__all__ = [
9+
"RowFilterRegistry",
10+
"get_row_filter_registry",
11+
"RowFilterPlan",
12+
"get_row_filter_specs",
13+
"apply_row_filter_plans",
14+
]
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
from __future__ import annotations
2+
3+
import logging
4+
from typing import Iterable
5+
6+
import sqlglot
7+
from sqlglot import exp
8+
9+
from celine.dataset.api.dataset_query.row_filters.models import RowFilterPlan
10+
11+
logger = logging.getLogger(__name__)
12+
13+
14+
def _table_name(table: exp.Table) -> str:
    """Return the bare identifier name of *table*.

    Falls back to the rendered SQL of the node when ``table.this`` is not a
    plain Identifier (the identifier text may itself contain dots if it was
    built that way).
    """
    identifier = table.args.get("this")
    if not isinstance(identifier, exp.Identifier):
        return table.sql()
    return identifier.this
20+
21+
22+
def _qualify_columns(expr: exp.Expression, alias: str) -> exp.Expression:
    """Return a copy of *expr* with every unqualified Column prefixed by *alias*.

    Columns that already carry a table qualifier are left untouched; the input
    expression itself is never mutated.
    """
    qualified = expr.copy()
    bare_columns = (
        column
        for column in qualified.find_all(exp.Column)
        if column.args.get("table") is None
    )
    for column in bare_columns:
        column.set("table", exp.Identifier(this=alias, quoted=False))
    return qualified
29+
30+
31+
def _add_where(select: exp.Select, condition: exp.Expression) -> None:
    """AND *condition* into the WHERE clause of *select* (mutates in place).

    Uses :func:`sqlglot.expressions.and_` instead of constructing ``exp.And``
    directly: ``and_`` parenthesizes lower-precedence operands, so an existing
    predicate like ``a OR b`` becomes ``(a OR b) AND condition``. Building the
    ``And`` node by hand would render as ``a OR b AND condition``, which binds
    the injected row filter to only one branch of the OR and can leak rows.
    """
    existing = select.args.get("where")
    if isinstance(existing, exp.Where):
        existing.set("this", exp.and_(existing.this, condition))
    else:
        select.set("where", exp.Where(this=condition))
38+
39+
40+
def _tables_in_select(select: exp.Select) -> list[exp.Table]:
    """Return the Table nodes whose nearest enclosing SELECT is *select* itself.

    Tables that belong to nested subqueries are excluded, because their
    closest ``exp.Select`` ancestor is the subquery, not *select*.
    """
    return [
        candidate
        for candidate in select.find_all(exp.Table)
        if candidate.find_ancestor(exp.Select) is select
    ]
47+
48+
49+
def apply_row_filter_plans(ast: exp.Expression, plans: Iterable[RowFilterPlan]) -> exp.Expression:
    """Apply row filter plans to an AST (returns a modified copy).

    Parameters:
        ast: parsed query expression; never mutated (a copy is edited).
        plans: resolved :class:`RowFilterPlan` objects; ``plan.table`` names the
            physical table the plan targets.

    Behavior:
        - If any plan has ``kind == "deny"``, a FALSE predicate is injected so
          the query returns no rows.
        - Otherwise, for each SELECT, every table with ``kind == "predicate"``
          plans gets its ``predicate_template`` (columns qualified with the
          table's alias) ANDed into that SELECT's WHERE clause.
    """
    plans_by_table: dict[str, list[RowFilterPlan]] = {}
    for plan in plans:
        plans_by_table.setdefault(plan.table, []).append(plan)

    out = ast.copy()

    # Any "deny" plan blanks the whole result set. Inject FALSE into *every*
    # SELECT — not just the first one found — so set-operation queries
    # (UNION / INTERSECT / EXCEPT) cannot leak rows through an unfiltered
    # branch.
    if any(p.kind == "deny" for ps in plans_by_table.values() for p in ps):
        for select in out.find_all(exp.Select):
            _add_where(select, exp.Boolean(this=False))
        return out

    for select in out.find_all(exp.Select):
        for table in _tables_in_select(select):
            table_plans = plans_by_table.get(_table_name(table))
            if not table_plans:
                continue

            # Qualify predicate columns with the alias actually used in the
            # query, so the filter targets the right table reference.
            alias = table.alias_or_name
            for plan in table_plans:
                if plan.kind != "predicate" or plan.predicate_template is None:
                    continue
                _add_where(select, _qualify_columns(plan.predicate_template, alias))

    return out
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
from __future__ import annotations
2+
3+
import time
4+
from dataclasses import dataclass
5+
from typing import Any, Callable, Generic, Optional, TypeVar
6+
7+
T = TypeVar("T")
8+
9+
10+
@dataclass
class _CacheEntry(Generic[T]):
    # One cache slot: the cached payload plus the absolute deadline (seconds,
    # on the clock used by the cache that stores it) after which it is stale.
    value: T
    expires_at: float
14+
15+
16+
class TTLCache(Generic[T]):
    """Very small in-memory TTL cache (process-local, not thread-safe).

    Entries are held as ``(value, expires_at)`` tuples. Deadlines are taken
    from ``time.monotonic()`` rather than ``time.time()`` so that wall-clock
    adjustments (NTP steps, manual changes) can neither prematurely expire
    entries nor keep stale ones alive.
    """

    def __init__(self, maxsize: int = 10_000):
        # maxsize bounds live entries; exceeding it triggers naive eviction.
        self._maxsize = maxsize
        self._store: dict[str, tuple[T, float]] = {}

    def get(self, key: str) -> Optional[T]:
        """Return the cached value for *key*, or ``None`` if absent/expired.

        Expired entries are removed lazily, on access.
        """
        entry = self._store.get(key)
        if entry is None:
            return None
        value, expires_at = entry
        if expires_at <= time.monotonic():
            self._store.pop(key, None)
            return None
        return value

    def set(self, key: str, value: T, ttl_seconds: int) -> None:
        """Store *value* under *key* for *ttl_seconds* seconds.

        A non-positive TTL is a deliberate no-op: the value is simply not
        cached (callers use this to disable caching per-handler).
        """
        if ttl_seconds <= 0:
            return
        # Overwriting an existing key does not grow the store, so it must not
        # trigger eviction of an unrelated entry.
        if key not in self._store and len(self._store) >= self._maxsize:
            # Naive eviction: drop expired entries first, then, if still full,
            # an arbitrary remaining key.
            now = time.monotonic()
            expired = [k for k, (_, deadline) in self._store.items() if deadline <= now]
            for k in expired:
                self._store.pop(k, None)
            if len(self._store) >= self._maxsize:
                self._store.pop(next(iter(self._store)), None)

        self._store[key] = (value, time.monotonic() + ttl_seconds)
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from __future__ import annotations
2+
3+
from .direct_user_match import DirectUserMatchHandler
4+
from .http_in_list import HttpInListHandler
5+
from .table_pointer import TablePointerHandler
6+
from .rec_registry import RecRegistryHandler
7+
8+
__all__ = [
9+
"DirectUserMatchHandler",
10+
"HttpInListHandler",
11+
"TablePointerHandler",
12+
"RecRegistryHandler",
13+
]

0 commit comments

Comments
 (0)