
Commit 7f79542

feat: support nebular database
1 parent f12d361 commit 7f79542

2 files changed: +260 / -10 lines changed


src/memos/graph_dbs/nebular.py

Lines changed: 75 additions & 9 deletions
@@ -113,7 +113,7 @@ def __init__(self, config: NebulaGraphDBConfig):
             username=config.get("user_name"),
             password=config.get("password"),
         )
-        self.db_name = config.db_name
+        self.db_name = config.space
         self.space = config.get("space")
         self.user_name = config.user_name
         self.system_db_name = "system" if config.use_multi_db else config.space
@@ -336,7 +336,6 @@ def edge_exists(
         query += "\nRETURN r"

         # Run the Cypher query
-        print("\n ======> query: ", query)
         result = self.client.execute(query)
         return result.one_or_none().values() is not None

@@ -661,7 +660,40 @@ def get_by_metadata(self, filters: list[dict[str, Any]]) -> list[str]:
         - Supports structured querying such as tag/category/importance/time filtering.
         - Can be used for faceted recall or prefiltering before embedding rerank.
         """
-        raise NotImplementedError
+        where_clauses = []
+        for _i, f in enumerate(filters):
+            field = f["field"]
+            op = f.get("op", "=")
+            value = f["value"]
+
+            # Build WHERE clause
+            if op == "=":
+                where_clauses.append(f"n.{field} = {value}")
+            elif op == "in":
+                where_clauses.append(f"n.{field} IN {value}")
+            elif op == "contains":
+                where_clauses.append(f"ANY(x IN {value} WHERE x IN n.{field})")
+            elif op == "starts_with":
+                where_clauses.append(f"n.{field} STARTS WITH {value}")
+            elif op == "ends_with":
+                where_clauses.append(f"n.{field} ENDS WITH {value}")
+            elif op in [">", ">=", "<", "<="]:
+                where_clauses.append(f"n.{field} {op} {value}")
+            else:
+                raise ValueError(f"Unsupported operator: {op}")
+
+        if not self.config.use_multi_db and self.config.user_name:
+            where_clauses.append(f"n.user_name = '{self.config.user_name}'")
+
+        where_str = " AND ".join(where_clauses)
+        query = f"MATCH (n@Memory) WHERE {where_str} RETURN n.id AS id"
+
+        try:
+            print("\n==========> query:\n", query)
+            result = self.client.execute(query)
+            return [record["id"].value for record in result]
+        except Exception as e:
+            logger.error(f"Failed to get metadata: {e}")

     def get_grouped_counts(
         self,
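For readers skimming the diff, the standalone sketch below (not part of the commit) mirrors how the new `get_by_metadata` turns a filter list into a gGQL WHERE clause. The helper name `build_where` and the sample filters are illustrative only; as in the tests further down, quoting of string values is left to the caller.

```python
# Illustrative sketch: reproduces the WHERE-clause assembly from get_by_metadata
# so you can see the statement it would send for a given filter list.
def build_where(filters: list[dict]) -> str:
    clauses = []
    for f in filters:
        field, op, value = f["field"], f.get("op", "="), f["value"]
        if op == "=":
            clauses.append(f"n.{field} = {value}")
        elif op == "in":
            clauses.append(f"n.{field} IN {value}")
        elif op == "contains":
            clauses.append(f"ANY(x IN {value} WHERE x IN n.{field})")
        elif op in (">", ">=", "<", "<="):
            clauses.append(f"n.{field} {op} {value}")
        else:
            raise ValueError(f"Unsupported operator: {op}")
    return " AND ".join(clauses)

filters = [
    {"field": "key", "op": "=", "value": '"AI Science"'},
    {"field": "confidence", "op": ">=", "value": 90.0},
]
print(f"MATCH (n@Memory) WHERE {build_where(filters)} RETURN n.id AS id")
# MATCH (n@Memory) WHERE n.key = "AI Science" AND n.confidence >= 90.0 RETURN n.id AS id
```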
@@ -827,7 +859,6 @@ def import_graph(self, data: dict[str, Any]) -> None:
             '''
             self.client.execute(edge_gql)

-    # TODO
     def get_all_memory_items(self, scope: str) -> list[dict]:
         """
         Retrieve all memory items of a specific memory_type.
@@ -838,7 +869,24 @@ def get_all_memory_items(self, scope: str) -> list[dict]:
         Returns:
             list[dict]: Full list of memory items under this scope.
         """
-        raise NotImplementedError
+        if scope not in {"WorkingMemory", "LongTermMemory", "UserMemory"}:
+            raise ValueError(f"Unsupported memory type scope: {scope}")
+
+        where_clause = f"WHERE n.memory_type = '{scope}'"
+
+        if not self.config.use_multi_db and self.config.user_name:
+            where_clause += f" AND n.user_name = '{self.config.user_name}'"
+
+        query = f"""
+        MATCH (n@Memory)
+        {where_clause}
+        RETURN n
+        """
+        try:
+            results = self.client.execute(query)
+            return [self._parse_node(record["n"]) for record in results]
+        except Exception as e:
+            logger.error(f"Failed to get memories: {e}")

     # TODO
     def get_structure_optimization_candidates(self, scope: str) -> list[dict]:
@@ -847,16 +895,35 @@ def get_structure_optimization_candidates(self, scope: str) -> list[dict]:
         - Isolated nodes, nodes with empty background, or nodes with exactly one child.
         - Plus: the child of any parent node that has exactly one child.
         """
-        raise NotImplementedError
+        where_clause = f"""
+            WHERE n.memory_type = '{scope}'
+            AND n.status = 'activated'
+            AND NOT ( (n)-[r@PARENT]->() OR ()-[r@PARENT]->(n) )
+        """
+
+        if not self.config.use_multi_db and self.config.user_name:
+            where_clause += f" AND n.user_name = '{self.config.user_name}'"
+
+        query = f"""
+        MATCH (n@Memory)
+        {where_clause}
+        RETURN n.id AS id, n AS node
+        """
+        try:
+            results = self.client.execute(query)
+            return [
+                self._parse_node({"id": record["id"], **dict(record["node"])}) for record in results
+            ]
+        except Exception as e:
+            logger.error(f"Failed : {e}")

-    # TODO
     def drop_database(self) -> None:
         """
         Permanently delete the entire database this instance is using.
         WARNING: This operation is destructive and cannot be undone.
         """
         if self.config.use_multi_db:
-            self.client.execute(f"DROP DATABASE {self.db_name} IF EXISTS")
+            self.client.execute(f"DROP GRAPH {self.db_name}")
             logger.info(f"Database '{self.db_name}' has been dropped.")
         else:
             raise ValueError(
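As a quick illustration (also not part of the commit), the sketch below assembles roughly the statement the new `get_structure_optimization_candidates` issues when `use_multi_db` is false and a user name is configured. The scope value and the user name `alice` are made-up example inputs.

```python
# Illustrative only: approximate query built by get_structure_optimization_candidates
# for scope="LongTermMemory" in single-database mode with user_name="alice".
scope, user_name = "LongTermMemory", "alice"

where_clause = f"""
    WHERE n.memory_type = '{scope}'
    AND n.status = 'activated'
    AND NOT ( (n)-[r@PARENT]->() OR ()-[r@PARENT]->(n) )
"""
# User scoping is appended when use_multi_db is False and a user_name is set.
where_clause += f" AND n.user_name = '{user_name}'"

query = f"""
MATCH (n@Memory)
{where_clause}
RETURN n.id AS id, n AS node
"""
print(query)
```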
@@ -992,7 +1059,6 @@ def _create_basic_property_indexes(self) -> None:
         """
         raise NotImplementedError

-    # TODO
     def _index_exists(self, index_name: str) -> bool:
         """
         Check if an index with the given name exists.

tests/graph_dbs/test_nebular.py

Lines changed: 185 additions & 1 deletion
@@ -29,7 +29,7 @@
     "hosts": json.loads(os.getenv("NEBULAR_HOSTS", "localhost")),
     "user_name": os.getenv("NEBULAR_USER", "root"),
     "password": os.getenv("NEBULAR_PASSWORD", "xxxxxx"),
-    "space": "test_memory_count",
+    "space": "memory_graph",
     "auto_create": True,
     "embedding_dimension": 3072,
     "use_multi_db": False,
@@ -224,3 +224,187 @@ def test_get_edges():
     assert edges[0]["from"] == source.id
     assert edges[0]["to"] == target.id
     assert edges[0]["type"] == "PARENT"
+
+
+def test_get_all_memory_items():
+    graph = GraphStoreFactory.from_config(
+        GraphDBConfigFactory(
+            backend="nebular",
+            config=nebular_config,
+        )
+    )
+    graph.clear()
+
+    # Insert 2 WorkingMemory items
+    for i in range(2):
+        mem = TextualMemoryItem(
+            memory=f"Memory {i}",
+            metadata=TreeNodeTextualMemoryMetadata(
+                memory_type="WorkingMemory",
+                key="Research Topic",
+                hierarchy_level="topic",
+                type="fact",
+                memory_time="2024-01-01",
+                status="activated",
+                visibility="public",
+                updated_at=now,
+                embedding=embed_memory_item(f"Memory {i}"),
+            ),
+        )
+        graph.add_node(mem.id, mem.memory, mem.metadata.model_dump(exclude_none=True))
+
+    # Retrieve all memory items of type WorkingMemory
+    items = graph.get_all_memory_items("WorkingMemory")
+    assert len(items) == 2
+    assert all(item["properties"]["memory_type"] == "WorkingMemory" for item in items)
+
+
+def test_get_structure_optimization_candidates():
+    graph = GraphStoreFactory.from_config(
+        GraphDBConfigFactory(
+            backend="nebular",
+            config=nebular_config,
+        )
+    )
+    graph.clear()
+
+    # Insert one isolated node (no parent or child)
+    mem = TextualMemoryItem(
+        memory="Isolated memory",
+        metadata=TreeNodeTextualMemoryMetadata(
+            memory_type="LongTermMemory",
+            key="Research Topic",
+            hierarchy_level="topic",
+            type="fact",
+            memory_time="2024-01-01",
+            status="activated",
+            visibility="public",
+            updated_at=now,
+            embedding=embed_memory_item("Isolated memory"),
+        ),
+    )
+    graph.add_node(mem.id, mem.memory, mem.metadata.model_dump(exclude_none=True))
+
+    # Insert one node with empty background (and no edges)
+    mem2 = TextualMemoryItem(
+        memory="Empty background memory",
+        metadata=TreeNodeTextualMemoryMetadata(
+            memory_type="LongTermMemory",
+            key="Research Topic",
+            hierarchy_level="topic",
+            type="fact",
+            memory_time="2024-01-01",
+            status="activated",
+            visibility="public",
+            updated_at=now,
+            embedding=embed_memory_item("Empty background memory"),
+        ),
+    )
+    graph.add_node(mem2.id, mem2.memory, mem2.metadata.model_dump(exclude_none=True))
+
+    # Find optimization candidates
+    candidates = graph.get_structure_optimization_candidates("LongTermMemory")
+    print("Optimization candidates:", candidates)
+    assert any("Isolated memory" in c["memory"] for c in candidates)
+    assert any("Empty background memory" in c["memory"] for c in candidates)
+
+
+def test_drop_database():
+    config = GraphDBConfigFactory(
+        backend="nebular",
+        config=nebular_config,
+    )
+    graph = GraphStoreFactory.from_config(config)
+
+    # Create a dummy node
+    mem = TextualMemoryItem(
+        memory="Temp for drop DB",
+        metadata=TreeNodeTextualMemoryMetadata(
+            memory_type="LongTermMemory",
+            key="Research Topic",
+            hierarchy_level="topic",
+            type="fact",
+            memory_time="2024-01-01",
+            status="activated",
+            visibility="public",
+            updated_at=now,
+            embedding=embed_memory_item("Temp for drop DB"),
+        ),
+    )
+    graph.add_node(mem.id, mem.memory, mem.metadata.model_dump(exclude_none=True))
+
+    # Drop the database
+    graph.drop_database()
+
+    # Attempting any operation afterward should raise an error or fail (optional)
+    try:
+        _ = graph.get_all_memory_items("WorkingMemory")
+    except Exception as e:
+        print("Expected exception after DB drop:", str(e))
+        assert "Current working graph not found" in str(e)
+
+
+def test_get_by_metadata():
+    config = GraphDBConfigFactory(
+        backend="nebular",
+        config=nebular_config,
+    )
+    graph = GraphStoreFactory.from_config(config)
+    graph.clear()
+
+    mem1 = TextualMemoryItem(
+        memory="AI for science",
+        metadata=TreeNodeTextualMemoryMetadata(
+            memory_type="LongTermMemory",
+            key="AI Science",
+            confidence=92.5,
+            tags=["AI", "science"],
+            hierarchy_level="topic",
+            type="fact",
+            memory_time="2024-01-01",
+            status="activated",
+            visibility="public",
+            updated_at=now,
+            embedding=embed_memory_item("AI for science"),
+        ),
+    )
+    mem2 = TextualMemoryItem(
+        memory="Neurosymbolic reasoning",
+        metadata=TreeNodeTextualMemoryMetadata(
+            memory_type="LongTermMemory",
+            key="Neurosymbolic",
+            tags=["symbolic", "reasoning"],
+            confidence=88.0,
+            hierarchy_level="topic",
+            type="fact",
+            memory_time="2024-01-01",
+            status="activated",
+            visibility="public",
+            updated_at=now,
+            embedding=embed_memory_item("Neurosymbolic reasoning"),
+        ),
+    )
+    graph.add_node(mem1.id, mem1.memory, mem1.metadata.model_dump(exclude_none=True))
+    graph.add_node(mem2.id, mem2.memory, mem2.metadata.model_dump(exclude_none=True))
+
+    # Exact match filter
+    result_ids = graph.get_by_metadata([{"field": "key", "op": "=", "value": '"AI Science"'}])
+    assert mem1.id in result_ids
+    assert mem2.id not in result_ids
+
+    # Confidence filter
+    result_ids = graph.get_by_metadata([{"field": "confidence", "op": ">=", "value": 90.0}])
+    assert mem1.id in result_ids
+    assert mem2.id not in result_ids
+
+    # Tag contains filter TODO
+    result_ids = graph.get_by_metadata([{"field": "tags", "op": "contains", "value": '["AI"]'}])
+    assert mem1.id in result_ids
+    assert mem2.id not in result_ids
+
+    # In set filter
+    result_ids = graph.get_by_metadata(
+        [{"field": "key", "op": "in", "value": '["AI Science", "Neurosymbolic"]'}]
+    )
+    assert mem1.id in result_ids
+    assert mem2.id in result_ids
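To try the new tests against a NebulaGraph deployment, a minimal sketch is shown below. It assumes a reachable instance; the host/port, password, and the exact JSON shape of `NEBULAR_HOSTS` are placeholders you would adapt to your setup (the test config parses that variable with `json.loads`).

```python
# Sketch: set the connection variables read by tests/graph_dbs/test_nebular.py,
# then run only the tests added in this commit. Values here are placeholders.
import os
import pytest

os.environ["NEBULAR_HOSTS"] = '["127.0.0.1:9669"]'  # must be valid JSON for json.loads
os.environ["NEBULAR_USER"] = "root"
os.environ["NEBULAR_PASSWORD"] = "your-password"

pytest.main([
    "tests/graph_dbs/test_nebular.py",
    "-k", "get_all_memory_items or structure_optimization or drop_database or get_by_metadata",
    "-v",
])
```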
