Skip to content

Commit 9023e97

Browse files
committed
feat: support nebular database
1 parent 832896f commit 9023e97

File tree

2 files changed

+111
-7
lines changed

2 files changed

+111
-7
lines changed

examples/basic_modules/nebular_example.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ def example_shared_db(db_name: str = "shared-traval-group"):
180180
print("\n=== Export entire DB (for verification, includes ALL users) ===")
181181
graph = GraphStoreFactory.from_config(config)
182182
all_graph_data = graph.export_graph()
183-
print(all_graph_data)
183+
print(str(all_graph_data)[:1000])
184184

185185
# Step 6: Search for alice's data only
186186
print("\n=== Search for travel_member_alice ===")
@@ -195,9 +195,9 @@ def example_shared_db(db_name: str = "shared-traval-group"):
195195
},
196196
)
197197
graph_alice = GraphStoreFactory.from_config(config_alice)
198-
nodes = graph_alice.search_by_embedding(vector=embed_memory_item("travel itinerary"), top_k=1)
198+
nodes = graph_alice.search_by_embedding(vector=embed_memory_item("travel itinerary"), top_k=3)
199199
for node in nodes:
200-
print(graph_alice.get_node(node["id"]))
200+
print(str(graph_alice.get_node(node["id"]))[:1000])
201201

202202

203203
if __name__ == "__main__":

src/memos/graph_dbs/nebular.py

Lines changed: 108 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from datetime import datetime
22
from typing import Any, Literal
33

4+
from nebulagraph_python.value_wrapper import ValueWrapper
5+
46
from memos.configs.graph_db import NebulaGraphDBConfig
57
from memos.dependency import require_python_package
68
from memos.graph_dbs.base import BaseGraphDB
@@ -170,7 +172,35 @@ def edge_exists(
170172

171173
# Graph Query & Reasoning
172174
def get_node(self, id: str) -> dict[str, Any] | None:
173-
raise NotImplementedError
175+
"""
176+
Retrieve a Memory node by its unique ID.
177+
178+
Args:
179+
id (str): Node ID (Memory.id)
180+
181+
Returns:
182+
dict: Node properties as key-value pairs, or None if not found.
183+
"""
184+
gql = f"""
185+
USE memory_graph
186+
MATCH (v {{id: '{id}'}})
187+
RETURN v
188+
"""
189+
190+
try:
191+
result = self.client.execute(gql)
192+
record = result.one_or_none()
193+
if record is None:
194+
return None
195+
196+
node_wrapper = record["v"].as_node()
197+
props = node_wrapper.get_properties()
198+
199+
return {key: self._parse_node(val) for key, val in props.items()}
200+
201+
except Exception as e:
202+
logger.error(f"[get_node] Failed to retrieve node '{id}': {e}")
203+
return None
174204

175205
def get_nodes(self, ids: list[str]) -> list[dict[str, Any]]:
176206
raise NotImplementedError
@@ -237,8 +267,49 @@ def search_by_embedding(
237267
- Typical use case: restrict to 'status = activated' to avoid
238268
matching archived or merged nodes.
239269
"""
270+
dim = len(vector)
271+
vector_str = ",".join(f"{float(x)}" for x in vector)
272+
gql_vector = f"VECTOR<{dim}, FLOAT>([{vector_str}])"
273+
274+
where_clauses = []
275+
if scope:
276+
where_clauses.append(f'n.memory_type == "{scope}"')
277+
if status:
278+
where_clauses.append(f'n.status == "{status}"')
279+
if not self.config.use_multi_db and self.config.user_name:
280+
where_clauses.append(f'n.user_name == "{self.config.user_name}"')
240281

241-
raise NotImplementedError
282+
where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""
283+
284+
gql = f"""
285+
USE memory_graph
286+
MATCH (n@Memory)
287+
{where_clause}
288+
ORDER BY euclidean(n.embedding, {gql_vector}) ASC
289+
APPROXIMATE
290+
LIMIT {top_k}
291+
OPTIONS {{ METRIC: L2, TYPE: IVF, NPROBE: 8 }}
292+
RETURN n.id AS id, euclidean(n.embedding, {gql_vector}) AS score
293+
"""
294+
295+
try:
296+
result = self.client.execute(gql)
297+
except Exception as e:
298+
logger.error(f"[search_by_embedding] Query failed: {e}")
299+
return []
300+
301+
try:
302+
output = []
303+
for row in result:
304+
values = row.values()
305+
id_val = values[0].as_string()
306+
score_val = values[1].as_double()
307+
if threshold is None or score_val <= threshold:
308+
output.append({"id": id_val, "score": score_val})
309+
return output
310+
except Exception as e:
311+
logger.error(f"[search_by_embedding] Result parse failed: {e}")
312+
return []
242313

243314
def get_by_metadata(self, filters: list[dict[str, Any]]) -> list[str]:
244315
raise NotImplementedError
@@ -356,10 +427,23 @@ def _ensure_database_exists(self):
356427
"""
357428
create_graph = "CREATE GRAPH IF NOT EXISTS memory_graph TYPED MemoryGraphType"
358429
set_graph_working = "SESSION SET GRAPH memory_graph"
430+
create_vector_index = """
431+
CREATE VECTOR INDEX IF NOT EXISTS memory_vector_index
432+
ON NODE Memory::embedding
433+
OPTIONS {
434+
DIM: 3072,
435+
METRIC: L2,
436+
TYPE: IVF,
437+
NLIST: 100,
438+
TRAINSIZE: 1000
439+
}
440+
FOR memory_graph
441+
"""
359442
try:
360443
self.client.execute(create_tag)
361444
self.client.execute(create_graph)
362445
self.client.execute(set_graph_working)
446+
self.client.execute(create_vector_index)
363447
logger.info("✅ Graph `memory_graph` is now the working graph.")
364448
except Exception as e:
365449
logger.error(f"❌ Failed to create tag: {e}")
@@ -378,5 +462,25 @@ def _create_basic_property_indexes(self) -> None:
378462
def _index_exists(self, index_name: str) -> bool:
379463
"""raise NotImplementedError"""
380464

381-
def _parse_node(self, node_data: dict[str, Any]) -> dict[str, Any]:
382-
return node_data
465+
def _parse_node(self, value: ValueWrapper) -> Any:
466+
if value is None or value.is_null():
467+
return None
468+
try:
469+
primitive_value = value.cast_primitive()
470+
except Exception as e:
471+
logger.warning(f"cast_primitive failed for value: {value}, error: {e}")
472+
try:
473+
primitive_value = value.cast()
474+
except Exception as e2:
475+
logger.warning(f"cast failed for value: {value}, error: {e2}")
476+
return str(value)
477+
478+
if isinstance(primitive_value, ValueWrapper):
479+
return self._parse_node(primitive_value)
480+
481+
if isinstance(primitive_value, list):
482+
return [
483+
self._parse_node(v) if isinstance(v, ValueWrapper) else v for v in primitive_value
484+
]
485+
486+
return primitive_value

0 commit comments

Comments
 (0)