Skip to content

Commit 4f1d246

Browse files
committed
feat: support nebular database
1 parent fc3f7fa commit 4f1d246

File tree

2 files changed

+107
-32
lines changed

2 files changed

+107
-32
lines changed

examples/basic_modules/nebular_example.py

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,50 @@
1+
import json
2+
import os
3+
14
from datetime import datetime
25

6+
import numpy as np
7+
8+
from dotenv import load_dotenv
9+
310
from memos.configs.embedder import EmbedderConfigFactory
411
from memos.configs.graph_db import GraphDBConfigFactory
512
from memos.embedders.factory import EmbedderFactory
613
from memos.graph_dbs.factory import GraphStoreFactory
714
from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata
815

916

17+
load_dotenv()
18+
1019
embedder_config = EmbedderConfigFactory.model_validate(
1120
{
1221
"backend": "universal_api",
1322
"config": {
1423
"provider": "openai",
15-
"api_key": "sk-pCuT1CqW4XfPmOZsZGp0ugF8xd4uU61nOrVm4JpWCz1dmWaT",
24+
"api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"),
1625
"model_name_or_path": "text-embedding-3-large",
17-
"base_url": "http://123.129.219.111:3000/v1",
26+
"base_url": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"),
1827
},
1928
}
2029
)
2130
embedder = EmbedderFactory.from_config(embedder_config)
2231

2332

2433
def embed_memory_item(memory: str) -> list[float]:
25-
return embedder.embed([memory])[0]
34+
embedding = embedder.embed([memory])[0]
35+
embedding_np = np.array(embedding, dtype=np.float32)
36+
embedding_list = embedding_np.tolist()
37+
return embedding_list
2638

2739

2840
def example_multi_db(db_name: str = "paper"):
2941
# Step 1: Build factory config
3042
config = GraphDBConfigFactory(
3143
backend="nebular",
3244
config={
33-
"hosts": ["106.14.142.60:9669", "120.55.160.164:9669", "106.15.38.5:9669"],
34-
"user_name": "root",
35-
"password": "Nebula123",
45+
"hosts": json.loads(os.getenv("NEBULAR_HOSTS", '["localhost"]')),
46+
"user_name": os.getenv("NEBULAR_USER", "root"),
47+
"password": os.getenv("NEBULAR_PASSWORD", "xxxxxx"),
3648
"space": db_name,
3749
"auto_create": True,
3850
"embedding_dimension": 3072,
@@ -81,17 +93,17 @@ def example_shared_db(db_name: str = "shared-traval-group"):
8193
Multiple users' data in the same NebulaGraph DB, with user_name as a tag.
8294
"""
8395
# users
84-
user_list = ["travel_member_alice", "travel_member_bob"]
96+
user_list = ["root"]
8597

8698
for user_name in user_list:
8799
# Step 1: Build factory config
88100
config = GraphDBConfigFactory(
89101
backend="nebular",
90102
config={
91-
"hosts": ["106.14.142.60:9669", "120.55.160.164:9669", "106.15.38.5:9669"],
92-
"password": "Nebula123",
103+
"hosts": json.loads(os.getenv("NEBULAR_HOSTS", '["localhost"]')),
104+
"user_name": os.getenv("NEBULAR_USER", "root"),
105+
"password": os.getenv("NEBULAR_PASSWORD", "xxxxxx"),
93106
"space": db_name,
94-
"user_name": "root",
95107
"auto_create": True,
96108
"embedding_dimension": 3072,
97109
"use_multi_db": False,
@@ -162,7 +174,6 @@ def example_shared_db(db_name: str = "shared-traval-group"):
162174

163175
# Link concept to topic
164176
graph.add_edge(source_id=concept.id, target_id=topic.id, type="RELATE_TO")
165-
166177
print(f"[INFO] Added nodes for {user_name}")
167178

168179
# Step 5: Query and print ALL for verification
@@ -176,10 +187,10 @@ def example_shared_db(db_name: str = "shared-traval-group"):
176187
config_alice = GraphDBConfigFactory(
177188
backend="nebular",
178189
config={
179-
"hosts": ["106.14.142.60:9669", "120.55.160.164:9669", "106.15.38.5:9669"],
180-
"password": "Nebula123",
190+
"hosts": json.loads(os.getenv("NEBULAR_HOSTS", '["localhost"]')),
191+
"user_name": os.getenv("NEBULAR_USER", "root"),
192+
"password": os.getenv("NEBULAR_PASSWORD", "xxxxxx"),
181193
"space": db_name,
182-
"user_name": "root",
183194
"embedding_dimension": 3072,
184195
},
185196
)

src/memos/graph_dbs/nebular.py

Lines changed: 82 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import json
2-
31
from datetime import datetime
42
from typing import Any, Literal
53

@@ -16,15 +14,20 @@ def _escape_str(value: str) -> str:
1614
return value.replace('"', '\\"')
1715

1816

19-
def _format_value(val: Any) -> str:
17+
def _format_value(val: Any, key: str = "") -> str:
2018
if isinstance(val, str):
2119
return f'"{_escape_str(val)}"'
22-
elif isinstance(val, int | float):
20+
elif isinstance(val, (int | float)):
2321
return str(val)
2422
elif isinstance(val, datetime):
2523
return f'datetime("{val.isoformat()}")'
2624
elif isinstance(val, list):
27-
return json.dumps(val)
25+
if key == "embedding":
26+
dim = len(val)
27+
joined = ",".join(str(float(x)) for x in val)
28+
return f"VECTOR<{dim}, FLOAT>([{joined}])"
29+
else:
30+
return f"[{', '.join(_format_value(v) for v in val)}]"
2831
elif val is None:
2932
return "NULL"
3033
else:
@@ -107,17 +110,18 @@ def add_node(self, id: str, memory: str, metadata: dict[str, Any]) -> None:
107110
"""
108111
Insert or update a Memory node in NebulaGraph.
109112
"""
113+
if not self.config.use_multi_db and self.config.user_name:
114+
metadata["user_name"] = self.config.user_name
115+
110116
now = datetime.utcnow()
111117
metadata = metadata.copy()
112118
metadata.setdefault("created_at", now)
113119
metadata.setdefault("updated_at", now)
114-
metadata.pop("embedding")
115120
metadata["node_type"] = metadata.pop("type")
116121
metadata["id"] = id
117122
metadata["memory"] = memory
118123

119-
print("metadata: ", metadata)
120-
properties = ", ".join(f"{k}: {_format_value(v)}" for k, v in metadata.items())
124+
properties = ", ".join(f"{k}: {_format_value(v, k)}" for k, v in metadata.items())
121125
gql = f"INSERT OR IGNORE (n@Memory {{{properties}}})"
122126

123127
try:
@@ -151,8 +155,6 @@ def add_edge(self, source_id: str, target_id: str, type: str):
151155
MATCH (a@Memory {{id: "{source_id}"}}), (b@Memory {{id: "{target_id}"}})
152156
INSERT (a) -[e@{type} {props}]-> (b)
153157
'''
154-
155-
print(f"[add_edge] Executing NGQL:\n{insert_stmt}")
156158
try:
157159
self.client.execute(insert_stmt)
158160
except Exception:
@@ -213,6 +215,29 @@ def search_by_embedding(
213215
status: str | None = None,
214216
threshold: float | None = None,
215217
) -> list[dict]:
218+
"""
219+
Retrieve node IDs based on vector similarity.
220+
221+
Args:
222+
vector (list[float]): The embedding vector representing query semantics.
223+
top_k (int): Number of top similar nodes to retrieve.
224+
scope (str, optional): Memory type filter (e.g., 'WorkingMemory', 'LongTermMemory').
225+
status (str, optional): Node status filter (e.g., 'active', 'archived').
226+
If provided, restricts results to nodes with matching status.
227+
threshold (float, optional): Minimum similarity score threshold (0 ~ 1).
228+
229+
Returns:
230+
list[dict]: A list of dicts with 'id' and 'score', ordered by similarity.
231+
232+
Notes:
233+
- Not yet implemented for the NebulaGraph backend: calling this method raises NotImplementedError.
234+
- If scope is provided, it restricts results to nodes with matching memory_type.
235+
- If 'status' is provided, only nodes with the matching status will be returned.
236+
- If threshold is provided, only results with score >= threshold will be returned.
237+
- Typical use case: restrict to 'status = activated' to avoid
238+
matching archived or merged nodes.
239+
"""
240+
216241
raise NotImplementedError
217242

218243
def get_by_metadata(self, filters: list[dict[str, Any]]) -> list[str]:
@@ -243,7 +268,50 @@ def clear(self) -> None:
243268
"""
244269

245270
def export_graph(self) -> dict[str, Any]:
246-
raise NotImplementedError
271+
"""
272+
Export all graph nodes and edges in a structured form.
273+
274+
Returns:
275+
{
276+
"nodes": [ { "id": ..., "memory": ..., "metadata": {...} }, ... ],
277+
"edges": [ { "source": ..., "target": ... }, ... ]
278+
}
279+
"""
280+
node_query = "MATCH (n@Memory)"
281+
edge_query = "MATCH (a@Memory)-[r]->(b@Memory)"
282+
283+
if not self.config.use_multi_db and self.config.user_name:
284+
username = self.config.user_name
285+
node_query += f' WHERE n.user_name = "{username}"'
286+
edge_query += f' WHERE r.user_name = "{username}"'
287+
288+
try:
289+
full_node_query = f"{node_query} RETURN n"
290+
node_result = self.client.execute(full_node_query)
291+
nodes = []
292+
for row in node_result:
293+
node_wrapper = row.values()[0].as_node()
294+
props = node_wrapper.get_properties()
295+
296+
metadata = {key: self._parse_node(val) for key, val in props.items()}
297+
298+
memory = metadata.get("memory", "")
299+
300+
nodes.append({"id": node_wrapper.get_id(), "memory": memory, "metadata": metadata})
301+
except Exception as e:
302+
raise RuntimeError(f"[EXPORT GRAPH - NODES] Exception: {e}") from e
303+
304+
try:
305+
full_edge_query = f"{edge_query} RETURN a.id AS source, b.id AS target"
306+
edge_result = self.client.execute(full_edge_query)
307+
edges = [
308+
{"source": row.values()[0].value, "target": row.values()[1].value}
309+
for row in edge_result
310+
]
311+
except Exception as e:
312+
raise RuntimeError(f"[EXPORT GRAPH - EDGES] Exception: {e}") from e
313+
314+
return {"nodes": nodes, "edges": edges}
247315

248316
def import_graph(self, data: dict[str, Any]) -> None:
249317
raise NotImplementedError
@@ -263,6 +331,7 @@ def _ensure_database_exists(self):
263331
NODE Memory (:MemoryTag {
264332
id STRING,
265333
memory STRING,
334+
user_name STRING,
266335
created_at STRING,
267336
updated_at STRING,
268337
status STRING,
@@ -279,19 +348,15 @@ def _ensure_database_exists(self):
279348
usage LIST<STRING>,
280349
background STRING,
281350
hierarchy_level STRING,
351+
embedding VECTOR<3072, FLOAT>,
282352
PRIMARY KEY(id)
283353
}),
284354
EDGE RELATE_TO (Memory) -[{user_name STRING}]-> (Memory)
285355
}
286356
"""
287357
create_graph = "CREATE GRAPH IF NOT EXISTS memory_graph TYPED MemoryGraphType"
288358
set_graph_working = "SESSION SET GRAPH memory_graph"
289-
290-
drop_graph = "DROP GRAPH memory_graph"
291-
drop_type = "DROP GRAPH TYPE MemoryGraphType"
292359
try:
293-
self.client.execute(drop_graph)
294-
self.client.execute(drop_type)
295360
self.client.execute(create_tag)
296361
self.client.execute(create_graph)
297362
self.client.execute(set_graph_working)
@@ -314,5 +379,4 @@ def _index_exists(self, index_name: str) -> bool:
314379
"""raise NotImplementedError"""
315380

316381
def _parse_node(self, node_data: dict[str, Any]) -> dict[str, Any]:
317-
""" """
318-
raise NotImplementedError
382+
return node_data

0 commit comments

Comments
 (0)