
Commit 084b241

feat: fix multi-tenant bug for nebula; add create index; fix some functions in nebula; fix example
1 parent 7f79542 commit 084b241

File tree

3 files changed: +131 -57 lines changed

examples/basic_modules/nebular_example.py
src/memos/configs/graph_db.py
src/memos/graph_dbs/nebular.py

examples/basic_modules/nebular_example.py

Lines changed: 30 additions & 10 deletions

@@ -16,6 +16,21 @@
 
 load_dotenv()
 
+
+def show(nebular_data):
+    from memos.configs.graph_db import Neo4jGraphDBConfig
+    from memos.graph_dbs.neo4j import Neo4jGraphDB
+
+    tree_config = Neo4jGraphDBConfig.from_json_file("../../examples/data/config/neo4j_config.json")
+    tree_config.use_multi_db = False
+    tree_config.db_name = "nebular-show"
+    tree_config.user_name = "nebular-show"
+
+    neo4j_db = Neo4jGraphDB(tree_config)
+    neo4j_db.clear()
+    neo4j_db.import_graph(nebular_data)
+
+
 embedder_config = EmbedderConfigFactory.model_validate(
     {
         "backend": "universal_api",

@@ -42,13 +57,13 @@ def example_multi_db(db_name: str = "paper"):
     config = GraphDBConfigFactory(
         backend="nebular",
         config={
-            "hosts": json.loads(os.getenv("NEBULAR_HOSTS", "localhost")),
-            "user_name": os.getenv("NEBULAR_USER", "root"),
+            "uri": json.loads(os.getenv("NEBULAR_HOSTS", "localhost")),
+            "user": os.getenv("NEBULAR_USER", "root"),
             "password": os.getenv("NEBULAR_PASSWORD", "xxxxxx"),
             "space": db_name,
+            "use_multi_db": True,
             "auto_create": True,
             "embedding_dimension": 3072,
-            "use_multi_db": True,
         },
     )
 

@@ -100,13 +115,14 @@ def example_shared_db(db_name: str = "shared-traval-group"):
     config = GraphDBConfigFactory(
         backend="nebular",
         config={
-            "hosts": json.loads(os.getenv("NEBULAR_HOSTS", "localhost")),
-            "user_name": os.getenv("NEBULAR_USER", "root"),
+            "uri": json.loads(os.getenv("NEBULAR_HOSTS", "localhost")),
+            "user": os.getenv("NEBULAR_USER", "root"),
             "password": os.getenv("NEBULAR_PASSWORD", "xxxxxx"),
             "space": db_name,
+            "user_name": user_name,
+            "use_multi_db": False,
             "auto_create": True,
             "embedding_dimension": 3072,
-            "use_multi_db": False,
         },
     )
 

@@ -215,13 +231,14 @@ def run_user_session(
     config = GraphDBConfigFactory(
         backend="nebular",
         config={
-            "hosts": json.loads(os.getenv("NEBULAR_HOSTS", "localhost")),
-            "user_name": os.getenv("NEBULAR_USER", "root"),
+            "uri": json.loads(os.getenv("NEBULAR_HOSTS", "localhost")),
+            "user": os.getenv("NEBULAR_USER", "root"),
             "password": os.getenv("NEBULAR_PASSWORD", "xxxxxx"),
             "space": db_name,
+            "user_name": user_name,
+            "use_multi_db": False,
             "auto_create": True,
             "embedding_dimension": 3072,
-            "use_multi_db": False,
         },
     )
     graph = GraphStoreFactory.from_config(config)

@@ -242,6 +259,7 @@ def run_user_session(
                 memory_time="2024-01-01",
                 status="activated",
                 visibility="public",
+                tags=["research", "rl"],
                 updated_at=now,
                 embedding=embed_memory_item(topic_text),
             ),

@@ -299,8 +317,10 @@ def run_user_session(
         node = graph.get_node(r["id"])
         print("🔍 Search result:", node["memory"])
 
+    all_nodes = graph.export_graph()
+    show(all_nodes)
+
     # === Step 5: Tag-based neighborhood discovery ===
-    # TODO
     neighbors = graph.get_neighbors_by_tag(["concept"], exclude_ids=[], top_k=2)
     print("📎 Tag-related nodes:", [neighbor["memory"] for neighbor in neighbors])
 
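The three config dicts above swap the old connection keys "hosts"/"user_name" for "uri"/"user", and in shared-database mode add a separate "user_name" key for the tenant together with "use_multi_db": False. A minimal sketch of the resulting dict shape follows, assuming it is passed as config= to GraphDBConfigFactory(backend="nebular", ...) exactly as the example does; the '["localhost:9669"]' default and the "alice" tenant name are illustrative only (the example itself defaults NEBULAR_HOSTS to "localhost"):

import json
import os

# Hedged sketch of the renamed connection keys in the example's config dict.
# "uri"/"user" replace the old "hosts"/"user_name" connection settings, while
# "user_name" now identifies the tenant in shared-database (use_multi_db=False) mode.
shared_db_config = {
    "uri": json.loads(os.getenv("NEBULAR_HOSTS", '["localhost:9669"]')),  # illustrative default
    "user": os.getenv("NEBULAR_USER", "root"),
    "password": os.getenv("NEBULAR_PASSWORD", "xxxxxx"),
    "space": "shared-traval-group",
    "user_name": "alice",        # tenant identity, as in example_shared_db / run_user_session
    "use_multi_db": False,
    "auto_create": True,
    "embedding_dimension": 3072,
}
print(shared_db_config)  # passed as config= to GraphDBConfigFactory(backend="nebular", ...) in the example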

src/memos/configs/graph_db.py

Lines changed: 2 additions & 4 deletions

@@ -9,7 +9,7 @@
 class BaseGraphDBConfig(BaseConfig):
     """Base class for all graph database configurations."""
 
-    uri: str
+    uri: str | list
     user: str
     password: str
 

@@ -103,7 +103,7 @@ def validate_community(self):
         return self
 
 
-class NebulaGraphDBConfig(BaseConfig):
+class NebulaGraphDBConfig(BaseGraphDBConfig):
     """
     NebulaGraph-specific configuration.
 

@@ -121,8 +121,6 @@ class NebulaGraphDBConfig(BaseConfig):
         user_name = "alice"
     """
 
-    password: str
-    hosts: list[str] = Field(..., description="List of host:port strings for NebulaGraph servers")
     space: str = Field(
         ..., description="The name of the target NebulaGraph space (like a database)"
     )
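With this change, NebulaGraphDBConfig subclasses BaseGraphDBConfig, so it inherits uri/user/password instead of declaring its own password and hosts fields, and uri may be either a plain string or a list. A hypothetical standalone mirror of that shape, for illustration only (the Demo* names are made up, and the real class may require additional fields such as use_multi_db or user_name that this hunk does not show):

from pydantic import BaseModel, Field

# Hypothetical mirror of the base config: `uri` now accepts a single address
# or a list of "host:port" strings.
class DemoGraphDBConfig(BaseModel):
    uri: str | list
    user: str
    password: str

# Subclass inherits uri/user/password, analogous to NebulaGraphDBConfig after this commit.
class DemoNebulaConfig(DemoGraphDBConfig):
    space: str = Field(..., description="Target NebulaGraph space")

print(DemoNebulaConfig(uri=["127.0.0.1:9669"], user="root", password="xxxxxx", space="paper"))
print(DemoNebulaConfig(uri="127.0.0.1:9669", user="root", password="xxxxxx", space="paper"))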

src/memos/graph_dbs/nebular.py

Lines changed: 99 additions & 43 deletions

@@ -109,8 +109,8 @@ def __init__(self, config: NebulaGraphDBConfig):
 
         self.config = config
         self.client = NebulaClient(
-            hosts=config.get("hosts"),
-            username=config.get("user_name"),
+            hosts=config.get("uri"),
+            username=config.get("user"),
             password=config.get("password"),
         )
         self.db_name = config.space

@@ -132,19 +132,11 @@ def create_index(
         dimensions: int = 3072,
         index_name: str = "memory_vector_index",
     ) -> None:
-        create_vector_index = f"""
-        CREATE VECTOR INDEX IF NOT EXISTS {index_name}
-        ON NODE Memory::{vector_property}
-        OPTIONS {{
-            DIM: {dimensions},
-            METRIC: L2,
-            TYPE: IVF,
-            NLIST: 100,
-            TRAINSIZE: 1000
-        }}
-        FOR memory_graph
-        """
-        self.client.execute(create_vector_index)
+        # Create vector index if it doesn't exist
+        if not self._vector_index_exists(index_name):
+            self._create_vector_index(label, vector_property, dimensions, index_name)
+        # Create indexes
+        self._create_basic_property_indexes()
 
     def remove_oldest_memory(self, memory_type: str, keep_latest: int) -> None:
         """

@@ -468,27 +460,54 @@ def get_neighbors_by_tag(
         Returns:
             List of dicts with node details and overlap count.
         """
-        where_user = ""
+        if not tags:
+            return []
+
+        where_clauses = [
+            'n.status = "activated"',
+            'NOT (n.node_type = "reasoning")',
+            'NOT (n.memory_type = "WorkingMemory")',
+        ]
+        if exclude_ids:
+            where_clauses.append(f"NOT (n.id IN {exclude_ids})")
+
         if not self.config.use_multi_db and self.config.user_name:
-            user_name = self.config.user_name
-            where_user = f"AND n.user_name = {user_name}"
+            where_clauses.append(f'n.user_name = "{self.config.user_name}"')
+
+        where_clause = " AND ".join(where_clauses)
+        tag_list_literal = "[" + ", ".join(f'"{_escape_str(t)}"' for t in tags) + "]"
 
         query = f"""
-        MATCH (n@Memory)
-        LET overlap_tags = [tag IN n.tags WHERE tag IN {tags}]
-        WHERE NOT n.id IN {exclude_ids}
-          AND n.status = 'activated'
-          AND n.node_type <> 'reasoning'
-          AND n.memory_type <> 'WorkingMemory'
-          {where_user}
-          AND size(overlap_tags) >= {min_overlap}
-        RETURN n, size(overlap_tags) AS overlap_count
-        ORDER BY overlap_count DESC
-        LIMIT {top_k}
-        """
-        print(query)
+        LET tag_list = {tag_list_literal}
+
+        MATCH (n@Memory)
+        WHERE {where_clause}
+        RETURN n.id AS id,
+               n.tags AS tags,
+               n.user_name AS user_name,
+               n.memory AS memory,
+               n.status AS status,
+               n.node_type AS node_type,
+               n.memory_type AS memory_type,
+               size( filter( n.tags, t -> t IN tag_list ) ) AS overlap_count
+        ORDER BY overlap_count DESC
+        LIMIT {top_k}
+        """
+
         result = self.client.execute(query)
-        return [self._parse_node(dict(record)) for record in result]
+        neighbors = []
+        for row in result:
+            props = {col: self._parse_node(row[col]) for col in result.column_names}
+
+            node_tags = props.get("tags", [])
+            overlap_tags = list(set(node_tags) & set(tags))
+
+            if len(overlap_tags) >= min_overlap:
+                props["overlap_count"] = len(overlap_tags)
+                neighbors.append(props)
+
+        neighbors.sort(key=lambda x: x["overlap_count"], reverse=True)
+        return neighbors[:top_k]
 
     def get_children_with_embeddings(self, id: str) -> list[dict[str, Any]]:
         where_user = ""

@@ -503,10 +522,16 @@ def get_children_with_embeddings(self, id: str) -> list[dict[str, Any]]:
         RETURN c.id AS id, c.embedding AS embedding, c.memory AS memory
         """
         result = self.client.execute(query)
-        return [
-            {"id": r["id"].value, "embedding": r["embedding"].value, "memory": r["memory"].value}
-            for r in result
-        ]
+        children = []
+        for r in result:
+            children.append(
+                {
+                    "id": self._parse_node(r["id"]),
+                    "embedding": self._parse_node(r["embedding"]),
+                    "memory": self._parse_node(r["memory"]),
+                }
+            )
+        return children
 
     def get_subgraph(
         self, center_id: str, depth: int = 2, center_status: str = "activated"

@@ -1038,26 +1063,50 @@ def _ensure_database_exists(self):
 
     # TODO
     def _vector_index_exists(self, index_name: str = "memory_vector_index") -> bool:
-        raise NotImplementedError
+        return False
 
-    # TODO
     def _create_vector_index(
         self, label: str, vector_property: str, dimensions: int, index_name: str
     ) -> None:
         """
         Create a vector index for the specified property in the label.
         """
-        raise NotImplementedError
+        create_vector_index = f"""
+        CREATE VECTOR INDEX IF NOT EXISTS {index_name}
+        ON NODE Memory::{vector_property}
+        OPTIONS {{
+            DIM: {dimensions},
+            METRIC: L2,
+            TYPE: IVF,
+            NLIST: 100,
+            TRAINSIZE: 1000
+        }}
+        FOR memory_graph
+        """
+        self.client.execute(create_vector_index)
 
-    # TODO
     def _create_basic_property_indexes(self) -> None:
         """
-        Create standard B-tree indexes on memory_type, created_at,
+        Create standard B-tree indexes on status, memory_type, created_at
         and updated_at fields.
         Create standard B-tree indexes on user_name when use Shared Database
-        Multi-Tenant Mode
+        Multi-Tenant Mode.
         """
-        raise NotImplementedError
+        fields = ["status", "memory_type", "created_at", "updated_at"]
+        if not self.config.use_multi_db:
+            fields.append("user_name")
+
+        for field in fields:
+            index_name = f"idx_memory_{field}"
+            gql = f"""
+            CREATE INDEX IF NOT EXISTS {index_name} ON NODE Memory({field})
+            FOR memory_graph
+            """
+            try:
+                self.client.execute(gql)
+                logger.info(f"✅ Created index: {index_name} on field {field}")
+            except Exception as e:
+                logger.error(f"❌ Failed to create index {index_name}: {e}")
 
     def _index_exists(self, index_name: str) -> bool:
         """

@@ -1086,4 +1135,11 @@ def _parse_node(self, value: ValueWrapper) -> Any:
                 self._parse_node(v) if isinstance(v, ValueWrapper) else v for v in primitive_value
             ]
 
+        if type(primitive_value).__name__ == "NVector":
+            try:
+                return list(primitive_value.values)
+            except Exception as e3:
+                logger.warning(f"Failed to convert NVector: {primitive_value}, error: {e3}")
+                return str(primitive_value)
+
         return primitive_value
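The rewritten get_neighbors_by_tag fetches candidate Memory nodes with the GQL shown above and then computes the tag overlap client-side before sorting and truncating. A standalone illustration of that post-filtering step, with made-up sample rows standing in for the parsed query results:

# Client-side overlap filtering, mirroring the new get_neighbors_by_tag logic.
tags = ["concept", "rl"]
min_overlap, top_k = 1, 2

# Made-up rows standing in for parsed GQL results.
rows = [
    {"id": "n1", "tags": ["concept", "graph"], "memory": "graph memo"},
    {"id": "n2", "tags": ["rl", "concept"], "memory": "rl memo"},
    {"id": "n3", "tags": ["vision"], "memory": "unrelated memo"},
]

neighbors = []
for props in rows:
    overlap_tags = set(props.get("tags", [])) & set(tags)
    if len(overlap_tags) >= min_overlap:
        props["overlap_count"] = len(overlap_tags)
        neighbors.append(props)

neighbors.sort(key=lambda x: x["overlap_count"], reverse=True)
print(neighbors[:top_k])  # n2 (overlap 2) then n1 (overlap 1); n3 is filtered out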
