Skip to content

Commit 176f35b

Browse files
committed
fix index and multi-db-name
1 parent ce46104 commit 176f35b

File tree

2 files changed

+31
-27
lines changed

2 files changed

+31
-27
lines changed

examples/basic_modules/nebular_example.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -197,9 +197,6 @@ def example_shared_db(db_name: str = "shared-traval-group"):
197197
all_graph_data = graph.export_graph()
198198
print(str(all_graph_data)[:1000])
199199

200-
all_nodes = graph.export_graph()
201-
show(all_nodes)
202-
203200
# Step 6: Search for alice's data only
204201
print("\n=== Search for travel_member_alice ===")
205202
config_alice = GraphDBConfigFactory(
@@ -320,9 +317,6 @@ def run_user_session(
320317
node = graph.get_node(r["id"])
321318
print("🔍 Search result:", node["memory"])
322319

323-
all_nodes = graph.export_graph()
324-
show(all_nodes)
325-
326320
# === Step 5: Tag-based neighborhood discovery ===
327321
neighbors = graph.get_neighbors_by_tag(["concept"], exclude_ids=[], top_k=2)
328322
print("📎 Tag-related nodes:", [neighbor["memory"] for neighbor in neighbors])

src/memos/graph_dbs/nebular.py

Lines changed: 31 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def __init__(self, config: NebulaGraphDBConfig):
103103
- hosts: list[str] like ["host1:port", "host2:port"]
104104
- user: str
105105
- password: str
106-
- space: str (optional for basic commands)
106+
- db_name: str (optional for basic commands)
107107
108108
Example config:
109109
{
@@ -122,12 +122,13 @@ def __init__(self, config: NebulaGraphDBConfig):
122122
password=config.get("password"),
123123
)
124124
self.db_name = config.space
125-
self.space = config.get("space")
126125
self.user_name = config.user_name
127126
self.system_db_name = "system" if config.use_multi_db else config.space
128127
if config.auto_create:
129128
self._ensure_database_exists()
130129

130+
self.client.execute(f"SESSION SET GRAPH `{self.db_name}`")
131+
131132
# Create only if not exists
132133
self.create_index(dimensions=config.embedding_dimension)
133134

@@ -140,9 +141,8 @@ def create_index(
140141
dimensions: int = 3072,
141142
index_name: str = "memory_vector_index",
142143
) -> None:
143-
# Create vector index if it doesn't exist
144-
if not self._vector_index_exists(index_name):
145-
self._create_vector_index(label, vector_property, dimensions, index_name)
144+
# Create vector index
145+
self._create_vector_index(label, vector_property, dimensions, index_name)
146146
# Create indexes
147147
self._create_basic_property_indexes()
148148

@@ -354,7 +354,7 @@ def get_node(self, id: str) -> dict[str, Any] | None:
354354
dict: Node properties as key-value pairs, or None if not found.
355355
"""
356356
gql = f"""
357-
USE memory_graph
357+
USE `{self.db_name}`
358358
MATCH (v {{id: '{id}'}})
359359
RETURN v
360360
"""
@@ -451,7 +451,6 @@ def get_edges(self, id: str, type: str = "ANY", direction: str = "ANY") -> list[
451451
)
452452
return edges
453453

454-
# TODO
455454
def get_neighbors_by_tag(
456455
self,
457456
tags: list[str],
@@ -643,7 +642,7 @@ def search_by_embedding(
643642
where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""
644643

645644
gql = f"""
646-
USE memory_graph
645+
USE `{self.db_name}`
647646
MATCH (n@Memory)
648647
{where_clause}
649648
ORDER BY inner_product(n.embedding, {gql_vector}) DESC
@@ -961,11 +960,11 @@ def drop_database(self) -> None:
961960
WARNING: This operation is destructive and cannot be undone.
962961
"""
963962
if self.config.use_multi_db:
964-
self.client.execute(f"DROP GRAPH {self.db_name}")
965-
logger.info(f"Database '{self.db_name}' has been dropped.")
963+
self.client.execute(f"DROP GRAPH `{self.db_name}`")
964+
logger.info(f"Database '`{self.db_name}`' has been dropped.")
966965
else:
967966
raise ValueError(
968-
f"Refusing to drop protected database: {self.db_name} in "
967+
f"Refusing to drop protected database: `{self.db_name}` in "
969968
f"Shared Database Multi-Tenant mode"
970969
)
971970

@@ -1063,21 +1062,17 @@ def _ensure_database_exists(self):
10631062
EDGE PARENT (Memory) -[{user_name STRING}]-> (Memory)
10641063
}
10651064
"""
1066-
create_graph = "CREATE GRAPH IF NOT EXISTS memory_graph TYPED MemoryGraphType"
1067-
set_graph_working = "SESSION SET GRAPH memory_graph"
1065+
create_graph = f"CREATE GRAPH IF NOT EXISTS `{self.db_name}` TYPED MemoryGraphType"
1066+
set_graph_working = f"SESSION SET GRAPH `{self.db_name}`"
10681067

10691068
try:
10701069
self.client.execute(create_tag)
10711070
self.client.execute(create_graph)
10721071
self.client.execute(set_graph_working)
1073-
logger.info("✅ Graph `memory_graph` is now the working graph.")
1072+
logger.info(f"✅ Graph ``{self.db_name}`` is now the working graph.")
10741073
except Exception as e:
10751074
logger.error(f"❌ Failed to create tag: {e}")
10761075

1077-
# TODO
1078-
def _vector_index_exists(self, index_name: str = "memory_vector_index") -> bool:
1079-
return False
1080-
10811076
def _create_vector_index(
10821077
self, label: str, vector_property: str, dimensions: int, index_name: str
10831078
) -> None:
@@ -1094,7 +1089,7 @@ def _create_vector_index(
10941089
NLIST: 100,
10951090
TRAINSIZE: 1000
10961091
}}
1097-
FOR memory_graph
1092+
FOR `{self.db_name}`
10981093
"""
10991094
self.client.execute(create_vector_index)
11001095

@@ -1113,7 +1108,7 @@ def _create_basic_property_indexes(self) -> None:
11131108
index_name = f"idx_memory_{field}"
11141109
gql = f"""
11151110
CREATE INDEX IF NOT EXISTS {index_name} ON NODE Memory({field})
1116-
FOR memory_graph
1111+
FOR `{self.db_name}`
11171112
"""
11181113
try:
11191114
self.client.execute(gql)
@@ -1125,7 +1120,22 @@ def _index_exists(self, index_name: str) -> bool:
11251120
"""
11261121
Check if an index with the given name exists.
11271122
"""
1128-
raise NotImplementedError
1123+
"""
1124+
Check if a vector index with the given name exists in NebulaGraph.
1125+
1126+
Args:
1127+
index_name (str): The name of the index to check.
1128+
1129+
Returns:
1130+
bool: True if the index exists, False otherwise.
1131+
"""
1132+
query = "SHOW VECTOR INDEXES"
1133+
try:
1134+
result = self.client.execute(query)
1135+
return any(row.values()[0].as_string() == index_name for row in result)
1136+
except Exception as e:
1137+
logger.error(f"[Nebula] Failed to check index existence: {e}")
1138+
return False
11291139

11301140
def _parse_node(self, value: ValueWrapper) -> Any:
11311141
if value is None or value.is_null():

0 commit comments

Comments
 (0)