Skip to content

Commit 25ca7a7

Browse files
authored
Add support for filter comparisons (MemMachine#676)
1 parent 56ff786 commit 25ca7a7

File tree

10 files changed

+241
-111
lines changed

10 files changed

+241
-111
lines changed

src/memmachine/common/episode_store/episode_sqlalchemy_store.py

Lines changed: 13 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
from sqlalchemy.sql import Select
3131
from sqlalchemy.sql.elements import ColumnElement
3232

33-
from memmachine.common.data_types import FilterablePropertyValue
3433
from memmachine.common.episode_store.episode_model import Episode as EpisodeE
3534
from memmachine.common.episode_store.episode_model import EpisodeEntry, EpisodeType
3635
from memmachine.common.episode_store.episode_storage import EpisodeIdT, EpisodeStorage
@@ -45,6 +44,7 @@
4544
from memmachine.common.filter.filter_parser import (
4645
Or as FilterOr,
4746
)
47+
from memmachine.common.filter.sql_filter_util import parse_sql_filter
4848

4949

5050
class BaseEpisodeStore(DeclarativeBase):
@@ -246,7 +246,9 @@ def _apply_episode_filter(
246246
filters: list[ColumnElement[bool]] = []
247247

248248
if filter_expr is not None:
249-
filters.append(self._compile_episode_filter_expr(filter_expr))
249+
parsed_filter = self._compile_episode_filter_expr(filter_expr)
250+
if parsed_filter is not None:
251+
filters.append(parsed_filter)
250252

251253
if start_time is not None:
252254
filters.append(Episode.created_at >= start_time)
@@ -266,40 +268,18 @@ def _apply_episode_filter(
266268
def _compile_episode_comparison_expr(
267269
self,
268270
expr: FilterComparison,
269-
) -> ColumnElement[bool]:
271+
) -> ColumnElement[bool] | None:
270272
column, is_metadata = self._resolve_episode_field(expr.field)
271273

272-
if column is None:
273-
raise ValueError(f"Unsupported episode filter field: {expr.field}")
274-
275-
if expr.op == "=":
276-
value = expr.value
277-
if isinstance(value, list):
278-
raise ValueError("'=' comparison cannot accept list values")
279-
if is_metadata:
280-
value = self._normalize_metadata_value(value)
281-
return column == value
282-
return column == value
283-
284-
if expr.op == "in":
285-
if not isinstance(expr.value, list):
286-
raise ValueError("IN comparison requires a list of values")
287-
288-
values = expr.value
289-
if is_metadata:
290-
values = [self._normalize_metadata_value(v) for v in values]
291-
292-
return column.in_(values)
293-
294-
if expr.op == "is_null":
295-
return column.is_(None)
296-
297-
if expr.op == "is_not_null":
298-
return column.is_not(None)
299-
300-
raise ValueError(f"Unsupported operator: {expr.op}")
274+
return parse_sql_filter(
275+
column=column,
276+
is_metadata=is_metadata,
277+
expr=expr,
278+
)
301279

302-
def _compile_episode_filter_expr(self, expr: FilterExpr) -> ColumnElement[bool]:
280+
def _compile_episode_filter_expr(
281+
self, expr: FilterExpr
282+
) -> ColumnElement[bool] | None:
303283
if isinstance(expr, FilterComparison):
304284
return self._compile_episode_comparison_expr(expr)
305285

@@ -315,14 +295,6 @@ def _compile_episode_filter_expr(self, expr: FilterExpr) -> ColumnElement[bool]:
315295

316296
raise TypeError(f"Unsupported filter expression type: {type(expr)!r}")
317297

318-
@staticmethod
319-
def _normalize_metadata_value(
320-
value: FilterablePropertyValue | list[FilterablePropertyValue],
321-
) -> str:
322-
if isinstance(value, bool):
323-
return "true" if value else "false"
324-
return "" if value is None else str(value)
325-
326298
@staticmethod
327299
def _resolve_episode_field(
328300
field: str,

src/memmachine/common/filter/filter_parser.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ class Comparison(FilterExpr):
2020
"""Filter comparison of a field against a value or list of values."""
2121

2222
field: str
23-
op: str # "=", "in", "is_null", "is_not_null"
23+
op: str # "=", "in", ">", "<", ">=", "<=", "is_null", "is_not_null"
2424
value: FilterablePropertyValue | list[FilterablePropertyValue]
2525

2626

@@ -57,7 +57,11 @@ class Token(NamedTuple):
5757
("LPAREN", r"\("),
5858
("RPAREN", r"\)"),
5959
("COMMA", r","),
60+
("GE", r">="),
61+
("LE", r"<="),
6062
("EQ", r"="),
63+
("GT", r">"),
64+
("LT", r"<"),
6165
("STRING", r"'[^']*'"),
6266
("IDENT", r"[A-Za-z0-9_\.]+"),
6367
("WS", r"\s+"),
@@ -155,10 +159,12 @@ def _parse_comparison(self) -> FilterExpr:
155159
field_tok = self._expect("IDENT")
156160
field = field_tok.value
157161

158-
if self._accept("EQ"):
159-
# field = value
162+
op_tok = self._accept("EQ", "GE", "LE", "GT", "LT")
163+
if op_tok:
164+
# field =/>=/>/</<= value
160165
value = self._parse_value()
161-
return Comparison(field=field, op="=", value=value)
166+
op = {"EQ": "=", "GE": ">=", "LE": "<=", "GT": ">", "LT": "<"}[op_tok.type]
167+
return Comparison(field=field, op=op, value=value)
162168

163169
if self._accept("IN"):
164170
self._expect("LPAREN")
@@ -179,7 +185,9 @@ def _parse_comparison(self) -> FilterExpr:
179185
op = "is_not_null" if negate else "is_null"
180186
return Comparison(field=field, op=op, value=None)
181187

182-
raise FilterParseError(f"Expected '=' or IN after field {field}")
188+
raise FilterParseError(
189+
f"Expected comparison operator (=, IN, >, <, >=, <=, IS) after field {field}"
190+
)
183191

184192
def _parse_value(self) -> FilterablePropertyValue:
185193
tok = self._expect("IDENT", "STRING")
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
"""SQLAlchemy utilities for FilterExpr."""
2+
3+
import logging
4+
from typing import Any
5+
6+
from sqlalchemy import ColumnElement
7+
from sqlalchemy.orm import InstrumentedAttribute, MappedColumn
8+
9+
from memmachine.common.data_types import FilterablePropertyValue
10+
from memmachine.common.filter.filter_parser import Comparison
11+
12+
logger = logging.getLogger(__name__)
13+
14+
15+
def _normalize_metadata_value(
16+
value: FilterablePropertyValue | list[FilterablePropertyValue],
17+
) -> str:
18+
if isinstance(value, bool):
19+
return "true" if value else "false"
20+
return "" if value is None else str(value)
21+
22+
23+
def _ensure_scalar_value(
24+
value: FilterablePropertyValue | list[FilterablePropertyValue], op: str
25+
) -> FilterablePropertyValue:
26+
if isinstance(value, list):
27+
raise TypeError(f"'{op}' comparison cannot accept list values")
28+
return value
29+
30+
31+
def _ensure_list_value(
32+
value: FilterablePropertyValue | list[FilterablePropertyValue],
33+
) -> list[FilterablePropertyValue]:
34+
if not isinstance(value, list):
35+
raise TypeError("IN comparison requires a list of values")
36+
return value
37+
38+
39+
def parse_sql_filter(
40+
column: MappedColumn[Any] | InstrumentedAttribute[Any] | None,
41+
is_metadata: bool,
42+
expr: Comparison,
43+
) -> ColumnElement[bool] | None:
44+
"""Parse a FilterExpr comparison into an SQLAlchemy boolean expression."""
45+
if column is None:
46+
logger.warning("Unsupported feature filter field: %s", expr.field)
47+
return None
48+
49+
op = expr.op
50+
normalize = _normalize_metadata_value if is_metadata else lambda v: v
51+
52+
match op:
53+
case "is_null":
54+
return column.is_(None)
55+
case "is_not_null":
56+
return column.is_not(None)
57+
case "in":
58+
values = _ensure_list_value(expr.value)
59+
if is_metadata:
60+
values = [normalize(v) for v in values]
61+
return column.in_(values)
62+
case ">" | "<" | ">=" | "<=" | "=":
63+
value = normalize(_ensure_scalar_value(expr.value, op))
64+
return {
65+
">": column > value,
66+
"<": column < value,
67+
">=": column >= value,
68+
"<=": column <= value,
69+
"=": column == value,
70+
}[op]
71+
case _:
72+
raise ValueError(f"Unsupported operator: {expr.op}")

src/memmachine/common/vector_graph_store/neo4j_vector_graph_store.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1616,10 +1616,14 @@ def _render_filter_expr(
16161616
params: dict[
16171617
str, FilterablePropertyValue | list[FilterablePropertyValue]
16181618
] = {}
1619-
if expr.op == "=":
1619+
if expr.op in (">", "<", ">=", "<=", "="):
16201620
if isinstance(expr.value, list):
1621-
raise ValueError("'=' comparison cannot accept list values")
1622-
condition = f"{field_ref} = ${query_value_parameter}.{param_name}"
1621+
raise ValueError(
1622+
f"'{expr.op}' comparison cannot accept list values"
1623+
)
1624+
condition = (
1625+
f"{field_ref} {expr.op} ${query_value_parameter}.{param_name}"
1626+
)
16231627
params[param_name] = expr.value
16241628
elif expr.op == "in":
16251629
if not isinstance(expr.value, list):

src/memmachine/semantic_memory/storage/neo4j_semantic_storage.py

Lines changed: 27 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -915,28 +915,7 @@ def _render_filter_expr(
915915
) -> tuple[str, dict[str, Any]]:
916916
if isinstance(expr, FilterComparison):
917917
field_ref = self._resolve_field_reference(alias, expr.field)
918-
params: dict[str, Any]
919-
if expr.op == "=":
920-
if isinstance(expr.value, list):
921-
raise ValueError("'=' comparison cannot accept list values")
922-
param_name = self._next_filter_param()
923-
condition = f"{field_ref} = ${param_name}"
924-
params = {param_name: expr.value}
925-
elif expr.op == "in":
926-
if not isinstance(expr.value, list):
927-
raise ValueError("IN comparison requires a list of values")
928-
param_name = self._next_filter_param()
929-
condition = f"{field_ref} IN ${param_name}"
930-
params = {param_name: expr.value}
931-
elif expr.op == "is_null":
932-
condition = f"{field_ref} IS NULL"
933-
params = {}
934-
elif expr.op == "is_not_null":
935-
condition = f"{field_ref} IS NOT NULL"
936-
params = {}
937-
else:
938-
raise ValueError(f"Unsupported operator: {expr.op}")
939-
return condition, params
918+
return self._render_comparison_condition(field_ref, expr)
940919
if isinstance(expr, FilterAnd):
941920
left_cond, left_params = self._render_filter_expr(alias, expr.left)
942921
right_cond, right_params = self._render_filter_expr(alias, expr.right)
@@ -951,6 +930,32 @@ def _render_filter_expr(
951930
return condition, left_params
952931
raise TypeError(f"Unsupported filter expression type: {type(expr)!r}")
953932

933+
def _render_comparison_condition(
934+
self, field_ref: str, expr: FilterComparison
935+
) -> tuple[str, dict[str, Any]]:
936+
op = expr.op
937+
params: dict[str, Any] = {}
938+
939+
if op == "in":
940+
if not isinstance(expr.value, list):
941+
raise ValueError("IN comparison requires a list of values")
942+
param = self._next_filter_param()
943+
return f"{field_ref} IN ${param}", {param: expr.value}
944+
945+
if op in (">", "<", ">=", "<=", "="):
946+
if isinstance(expr.value, list):
947+
raise ValueError(f"'{op}' comparison cannot accept list values")
948+
param = self._next_filter_param()
949+
return f"{field_ref} {op} ${param}", {param: expr.value}
950+
951+
if op == "is_null":
952+
return f"{field_ref} IS NULL", params
953+
954+
if op == "is_not_null":
955+
return f"{field_ref} IS NOT NULL", params
956+
957+
raise ValueError(f"Unsupported operator: {op}")
958+
954959
def _resolve_field_reference(self, alias: str, field: str) -> str:
955960
if field.startswith(("m.", "metadata.")):
956961
key = field.split(".", 1)[1]

src/memmachine/semantic_memory/storage/sqlalchemy_pgvector_semantic.py

Lines changed: 16 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,15 @@
3030
from sqlalchemy.dialects.postgresql import JSONB
3131
from sqlalchemy.engine import Connection
3232
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker
33-
from sqlalchemy.orm import DeclarativeBase, aliased, mapped_column
33+
from sqlalchemy.orm import (
34+
DeclarativeBase,
35+
InstrumentedAttribute,
36+
MappedColumn,
37+
aliased,
38+
mapped_column,
39+
)
3440
from sqlalchemy.sql import Delete, Select, func
3541

36-
from memmachine.common.data_types import FilterablePropertyValue
3742
from memmachine.common.episode_store.episode_model import EpisodeIdT
3843
from memmachine.common.errors import InvalidArgumentError, ResourceNotFoundError
3944
from memmachine.common.filter.filter_parser import (
@@ -48,6 +53,7 @@
4853
from memmachine.common.filter.filter_parser import (
4954
Or as FilterOr,
5055
)
56+
from memmachine.common.filter.sql_filter_util import parse_sql_filter
5157
from memmachine.semantic_memory.semantic_model import SemanticFeature, SetIdT
5258
from memmachine.semantic_memory.storage.storage_base import (
5359
FeatureIdT,
@@ -616,35 +622,11 @@ def _compile_feature_comparison_expr(
616622
) -> ColumnElement[bool] | None:
617623
column, is_metadata = self._resolve_feature_field(table, expr.field)
618624

619-
if column is None:
620-
logger.warning("Unsupported feature filter field: %s", expr.field)
621-
return None
622-
623-
if expr.op == "=":
624-
value = expr.value
625-
if isinstance(value, list):
626-
raise ValueError("'=' comparison cannot accept list values")
627-
if is_metadata:
628-
value = self._normalize_metadata_value(value)
629-
return column == value
630-
return column == value
631-
632-
if expr.op == "in":
633-
if not isinstance(expr.value, list):
634-
raise ValueError("IN comparison requires a list of values")
635-
636-
values = expr.value
637-
if is_metadata:
638-
values = [self._normalize_metadata_value(v) for v in values]
639-
return column.in_(values)
640-
641-
if expr.op == "is_null":
642-
return column.is_(None)
643-
644-
if expr.op == "is_not_null":
645-
return column.is_not(None)
646-
647-
raise ValueError(f"Unsupported operator: {expr.op}")
625+
return parse_sql_filter(
626+
column=column,
627+
is_metadata=is_metadata,
628+
expr=expr,
629+
)
648630

649631
def _compile_feature_filter_expr(
650632
self,
@@ -668,19 +650,13 @@ def _compile_feature_filter_expr(
668650

669651
raise TypeError(f"Unsupported filter expression type: {type(expr)!r}")
670652

671-
@staticmethod
672-
def _normalize_metadata_value(
673-
value: FilterablePropertyValue | list[FilterablePropertyValue],
674-
) -> str:
675-
if isinstance(value, bool):
676-
return "true" if value else "false"
677-
return "" if value is None else str(value)
678-
679653
@staticmethod
680654
def _resolve_feature_field(
681655
table: type[Feature],
682656
field: str,
683-
) -> tuple[Any, bool] | tuple[None, bool]:
657+
) -> (
658+
tuple[MappedColumn[Any] | InstrumentedAttribute[Any], bool] | tuple[None, bool]
659+
):
684660
normalized = field
685661
field_mapping = {
686662
"set_id": table.set_id,

0 commit comments

Comments
 (0)