Skip to content

Commit ce46104

Browse files
committed
fix: search bug;
1 parent 084b241 commit ce46104

File tree

2 files changed

+27
-11
lines changed

2 files changed

+27
-11
lines changed

examples/basic_modules/nebular_example.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,8 @@ def show(nebular_data):
2222
from memos.graph_dbs.neo4j import Neo4jGraphDB
2323

2424
tree_config = Neo4jGraphDBConfig.from_json_file("../../examples/data/config/neo4j_config.json")
25-
tree_config.use_multi_db = False
25+
tree_config.use_multi_db = True
2626
tree_config.db_name = "nebular-show"
27-
tree_config.user_name = "nebular-show"
2827

2928
neo4j_db = Neo4jGraphDB(tree_config)
3029
neo4j_db.clear()
@@ -108,7 +107,7 @@ def example_shared_db(db_name: str = "shared-traval-group"):
108107
Multiple users' data in the same Neo4j DB with user_name as a tag.
109108
"""
110109
# users
111-
user_list = ["root"]
110+
user_list = ["travel_member_alice", "travel_member_bob"]
112111

113112
for user_name in user_list:
114113
# Step 1: Build factory config
@@ -198,15 +197,19 @@ def example_shared_db(db_name: str = "shared-traval-group"):
198197
all_graph_data = graph.export_graph()
199198
print(str(all_graph_data)[:1000])
200199

200+
all_nodes = graph.export_graph()
201+
show(all_nodes)
202+
201203
# Step 6: Search for alice's data only
202204
print("\n=== Search for travel_member_alice ===")
203205
config_alice = GraphDBConfigFactory(
204206
backend="nebular",
205207
config={
206-
"hosts": json.loads(os.getenv("NEBULAR_HOSTS", "localhost")),
207-
"user_name": os.getenv("NEBULAR_USER", "root"),
208+
"uri": json.loads(os.getenv("NEBULAR_HOSTS", "localhost")),
209+
"user": os.getenv("NEBULAR_USER", "root"),
208210
"password": os.getenv("NEBULAR_PASSWORD", "xxxxxx"),
209211
"space": db_name,
212+
"user_name": user_list[0],
210213
"auto_create": True,
211214
"embedding_dimension": 3072,
212215
"use_multi_db": False,
@@ -339,7 +342,7 @@ def run_user_session(
339342
graph.update_node(
340343
concept_items[0].id, {"confidence": 99.0, "created_at": "2025-07-24T20:11:56.375687"}
341344
)
342-
graph.remove_oldest_memory("LongTermMemory", keep_latest=3)
345+
graph.remove_oldest_memory("WorkingMemory", keep_latest=1)
343346
graph.delete_edge(topic.id, concept_items[0].id, type="PARENT")
344347
graph.delete_node(concept_items[1].id)
345348

src/memos/graph_dbs/nebular.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from datetime import datetime
22
from typing import Any, Literal
33

4+
import numpy as np
5+
46
from nebulagraph_python.py_data_types import NVector
57
from nebulagraph_python.value_wrapper import ValueWrapper
68

@@ -13,6 +15,12 @@
1315
logger = get_logger(__name__)
1416

1517

18+
def _normalize(vec: list[float]) -> list[float]:
19+
v = np.asarray(vec, dtype=np.float32)
20+
norm = np.linalg.norm(v)
21+
return (v / (norm if norm else 1.0)).tolist()
22+
23+
1624
def _compose_node(item: dict[str, Any]) -> tuple[str, str, dict[str, Any]]:
1725
node_id = item["id"]
1826
memory = item["memory"]
@@ -36,7 +44,7 @@ def _prepare_node_metadata(metadata: dict[str, Any]) -> dict[str, Any]:
3644
# Normalize embedding type
3745
embedding = metadata.get("embedding")
3846
if embedding and isinstance(embedding, list):
39-
metadata["embedding"] = [float(x) for x in embedding]
47+
metadata["embedding"] = _normalize([float(x) for x in embedding])
4048

4149
return metadata
4250

@@ -175,6 +183,9 @@ def add_node(self, id: str, memory: str, metadata: dict[str, Any]) -> None:
175183
metadata["id"] = id
176184
metadata["memory"] = memory
177185

186+
if "embedding" in metadata and isinstance(metadata["embedding"], list):
187+
metadata["embedding"] = _normalize(metadata["embedding"])
188+
178189
properties = ", ".join(f"{k}: {_format_value(v, k)}" for k, v in metadata.items())
179190
gql = f"INSERT OR IGNORE (n@Memory {{{properties}}})"
180191

@@ -616,6 +627,7 @@ def search_by_embedding(
616627
- Typical use case: restrict to 'status = activated' to avoid
617628
matching archived or merged nodes.
618629
"""
630+
vector = _normalize(vector)
619631
dim = len(vector)
620632
vector_str = ",".join(f"{float(x)}" for x in vector)
621633
gql_vector = f"VECTOR<{dim}, FLOAT>([{vector_str}])"
@@ -634,11 +646,11 @@ def search_by_embedding(
634646
USE memory_graph
635647
MATCH (n@Memory)
636648
{where_clause}
637-
ORDER BY euclidean(n.embedding, {gql_vector}) ASC
649+
ORDER BY inner_product(n.embedding, {gql_vector}) DESC
638650
APPROXIMATE
639651
LIMIT {top_k}
640-
OPTIONS {{ METRIC: L2, TYPE: IVF, NPROBE: 8 }}
641-
RETURN n.id AS id, euclidean(n.embedding, {gql_vector}) AS score
652+
OPTIONS {{ METRIC: IP, TYPE: IVF, NPROBE: 8 }}
653+
RETURN n.id AS id, inner_product(n.embedding, {gql_vector}) AS score
642654
"""
643655

644656
try:
@@ -653,6 +665,7 @@ def search_by_embedding(
653665
values = row.values()
654666
id_val = values[0].as_string()
655667
score_val = values[1].as_double()
668+
score_val = (score_val + 1) / 2 # align to neo4j, Normalized Cosine Score
656669
if threshold is None or score_val <= threshold:
657670
output.append({"id": id_val, "score": score_val})
658671
return output
@@ -1076,7 +1089,7 @@ def _create_vector_index(
10761089
ON NODE Memory::{vector_property}
10771090
OPTIONS {{
10781091
DIM: {dimensions},
1079-
METRIC: L2,
1092+
METRIC: IP,
10801093
TYPE: IVF,
10811094
NLIST: 100,
10821095
TRAINSIZE: 1000

0 commit comments

Comments
 (0)