Skip to content

Commit 80df31c

Browse files
authored
fix: add metadata filter and get by metadata bug (#184)
* fix: add metadata filter and get by metadata bug
* fix: add metadata filter and get by metadata bug
1 parent d3c6238 commit 80df31c

File tree

1 file changed

+42
-3
lines changed

1 file changed

+42
-3
lines changed

src/memos/graph_dbs/nebular.py

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def _prepare_node_metadata(metadata: dict[str, Any]) -> dict[str, Any]:
3838
- Convert embedding to list of float if present.
3939
"""
4040
now = datetime.utcnow().isoformat()
41+
metadata["node_type"] = metadata.pop("type")
4142

4243
# Fill timestamps if missing
4344
metadata.setdefault("created_at", now)
@@ -51,6 +52,44 @@ def _prepare_node_metadata(metadata: dict[str, Any]) -> dict[str, Any]:
5152
return metadata
5253

5354

55+
def _metadata_filter(metadata: dict[str, Any]) -> dict[str, Any]:
56+
"""
57+
Filter and validate metadata dictionary against the Memory node schema.
58+
- Removes keys not in schema.
59+
- Warns if required fields are missing.
60+
"""
61+
62+
allowed_fields = {
63+
"id",
64+
"memory",
65+
"user_name",
66+
"user_id",
67+
"session_id",
68+
"status",
69+
"key",
70+
"confidence",
71+
"tags",
72+
"created_at",
73+
"updated_at",
74+
"memory_type",
75+
"sources",
76+
"source",
77+
"node_type",
78+
"visibility",
79+
"usage",
80+
"background",
81+
"embedding",
82+
}
83+
84+
missing_fields = allowed_fields - metadata.keys()
85+
if missing_fields:
86+
logger.warning(f"Metadata missing required fields: {sorted(missing_fields)}")
87+
88+
filtered_metadata = {k: v for k, v in metadata.items() if k in allowed_fields}
89+
90+
return filtered_metadata
91+
92+
5493
def _escape_str(value: str) -> str:
5594
return value.replace('"', '\\"')
5695

@@ -281,6 +320,7 @@ def add_node(self, id: str, memory: str, metadata: dict[str, Any]) -> None:
281320
if "embedding" in metadata and isinstance(metadata["embedding"], list):
282321
metadata["embedding"] = _normalize(metadata["embedding"])
283322

323+
metadata = _metadata_filter(metadata)
284324
properties = ", ".join(f"{k}: {_format_value(v, k)}" for k, v in metadata.items())
285325
gql = f"INSERT OR IGNORE (n@Memory {{{properties}}})"
286326

@@ -830,8 +870,7 @@ def get_by_metadata(self, filters: list[dict[str, Any]]) -> list[str]:
830870

831871
def _escape_value(value):
832872
if isinstance(value, str):
833-
escaped = value.replace('"', '\\"')
834-
return f'"{escaped}"'
873+
return f'"{value}"'
835874
elif isinstance(value, list):
836875
return "[" + ", ".join(_escape_value(v) for v in value) + "]"
837876
else:
@@ -850,7 +889,7 @@ def _escape_value(value):
850889
elif op == "in":
851890
where_clauses.append(f"n.{field} IN {escaped_value}")
852891
elif op == "contains":
853-
where_clauses.append(f"ANY(x IN n.{field} WHERE x = {escaped_value})")
892+
where_clauses.append(f"size(filter(n.{field}, t -> t IN {escaped_value})) > 0")
854893
elif op == "starts_with":
855894
where_clauses.append(f"n.{field} STARTS WITH {escaped_value}")
856895
elif op == "ends_with":

0 commit comments

Comments
 (0)