
Commit 301178d

feat: nebula&reorganize update (#322)
* feat: update nebula to nebula 5.1.1
* fix: bug in nebula and manager
* feat: update product
* test: update
* fix: bug in finding reorganize node
* fix: duplicate bug
* chore: delete useless annotation
* feat: update nebula init
* feat: update reorganize
1 parent beb0e07 commit 301178d

File tree

5 files changed: +206, -48 lines changed

src/memos/graph_dbs/nebular.py

Lines changed: 130 additions & 22 deletions
@@ -24,6 +24,19 @@
 logger = get_logger(__name__)
 
 
+_TRANSIENT_ERR_KEYS = (
+    "Session not found",
+    "Connection not established",
+    "timeout",
+    "deadline exceeded",
+    "Broken pipe",
+    "EOFError",
+    "socket closed",
+    "connection reset",
+    "connection refused",
+)
+
+
 @timed
 def _normalize(vec: list[float]) -> list[float]:
     v = np.asarray(vec, dtype=np.float32)
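Note: the new _TRANSIENT_ERR_KEYS tuple is matched case-insensitively against exception messages further down in execute_query. As a rough sketch of that check in isolation (the helper name below is illustrative, not part of this commit):

def _is_transient_error(exc: Exception) -> bool:
    """Return True if the error message mentions any known transient failure."""
    msg = str(exc).lower()
    return any(key.lower() in msg for key in _TRANSIENT_ERR_KEYS)

# _is_transient_error(RuntimeError("Session not found")) -> True
# _is_transient_error(ValueError("bad query syntax"))    -> False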
@@ -99,6 +112,7 @@ class NebulaGraphDB(BaseGraphDB):
     _CLIENT_CACHE: ClassVar[dict[str, "NebulaClient"]] = {}
     _CLIENT_REFCOUNT: ClassVar[dict[str, int]] = {}
     _CLIENT_LOCK: ClassVar[Lock] = Lock()
+    _CLIENT_INIT_DONE: ClassVar[set[str]] = set()
 
     @staticmethod
     def _get_hosts_from_cfg(cfg: NebulaGraphDBConfig) -> list[str]:
@@ -115,13 +129,53 @@ def _make_client_key(cfg: NebulaGraphDBConfig) -> str:
                 "nebula-sync",
                 ",".join(hosts),
                 str(getattr(cfg, "user", "")),
-                str(getattr(cfg, "password", "")),
                 str(getattr(cfg, "use_multi_db", False)),
+                str(getattr(cfg, "space", "")),
             ]
         )
 
     @classmethod
-    def _get_or_create_shared_client(cls, cfg: NebulaGraphDBConfig) -> (tuple)[str, "NebulaClient"]:
+    def _bootstrap_admin(cls, cfg: NebulaGraphDBConfig, client: "NebulaClient") -> "NebulaGraphDB":
+        tmp = object.__new__(NebulaGraphDB)
+        tmp.config = cfg
+        tmp.db_name = cfg.space
+        tmp.user_name = getattr(cfg, "user_name", None)
+        tmp.embedding_dimension = getattr(cfg, "embedding_dimension", 3072)
+        tmp.default_memory_dimension = 3072
+        tmp.common_fields = {
+            "id",
+            "memory",
+            "user_name",
+            "user_id",
+            "session_id",
+            "status",
+            "key",
+            "confidence",
+            "tags",
+            "created_at",
+            "updated_at",
+            "memory_type",
+            "sources",
+            "source",
+            "node_type",
+            "visibility",
+            "usage",
+            "background",
+        }
+        tmp.base_fields = set(tmp.common_fields) - {"usage"}
+        tmp.heavy_fields = {"usage"}
+        tmp.dim_field = (
+            f"embedding_{tmp.embedding_dimension}"
+            if str(tmp.embedding_dimension) != str(tmp.default_memory_dimension)
+            else "embedding"
+        )
+        tmp.system_db_name = "system" if getattr(cfg, "use_multi_db", False) else cfg.space
+        tmp._client = client
+        tmp._owns_client = False
+        return tmp
+
+    @classmethod
+    def _get_or_create_shared_client(cls, cfg: NebulaGraphDBConfig) -> tuple[str, "NebulaClient"]:
         from nebulagraph_python import (
             ConnectionConfig,
             NebulaClient,
@@ -159,7 +213,60 @@ def _get_or_create_shared_client(cls, cfg: NebulaGraphDBConfig) -> (tuple)[str,
                 logger.info(f"[NebulaGraphDBSync] Created shared NebulaClient key={key}")
 
             cls._CLIENT_REFCOUNT[key] = cls._CLIENT_REFCOUNT.get(key, 0) + 1
-            return key, client
+
+        if getattr(cfg, "auto_create", False) and key not in cls._CLIENT_INIT_DONE:
+            try:
+                pass
+            finally:
+                pass
+
+        if getattr(cfg, "auto_create", False) and key not in cls._CLIENT_INIT_DONE:
+            with cls._CLIENT_LOCK:
+                if key not in cls._CLIENT_INIT_DONE:
+                    admin = cls._bootstrap_admin(cfg, client)
+                    try:
+                        admin._ensure_database_exists()
+                        admin._create_basic_property_indexes()
+                        admin._create_vector_index(
+                            label="Memory",
+                            vector_property=admin.dim_field,
+                            dimensions=int(
+                                admin.embedding_dimension or admin.default_memory_dimension
+                            ),
+                            index_name="memory_vector_index",
+                        )
+                        cls._CLIENT_INIT_DONE.add(key)
+                        logger.info("[NebulaGraphDBSync] One-time init done")
+                    except Exception:
+                        logger.exception("[NebulaGraphDBSync] One-time init failed")
+
+        return key, client
+
+    def _refresh_client(self):
+        """
+        refresh NebulaClient:
+        """
+        old_key = getattr(self, "_client_key", None)
+        if not old_key:
+            return
+
+        cls = self.__class__
+        with cls._CLIENT_LOCK:
+            try:
+                if old_key in cls._CLIENT_CACHE:
+                    try:
+                        cls._CLIENT_CACHE[old_key].close()
+                    except Exception as e:
+                        logger.warning(f"[refresh_client] close old client error: {e}")
+                    finally:
+                        cls._CLIENT_CACHE.pop(old_key, None)
+            finally:
+                cls._CLIENT_REFCOUNT[old_key] = 0
+
+        new_key, new_client = cls._get_or_create_shared_client(self.config)
+        self._client_key = new_key
+        self._client = new_client
+        logger.info(f"[NebulaGraphDBSync] client refreshed: {old_key} -> {new_key}")
 
     @classmethod
     def _release_shared_client(cls, key: str):
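Note: the one-time bootstrap above re-checks `key not in cls._CLIENT_INIT_DONE` inside `cls._CLIENT_LOCK` so that two threads racing past the unlocked check cannot both run schema and index creation (double-checked locking). A minimal, generic sketch of the same pattern, with names invented for illustration:

import threading

_INIT_DONE: set[str] = set()
_INIT_LOCK = threading.Lock()

def ensure_initialized(key: str, init_fn) -> None:
    """Run init_fn(key) at most once per key, even with concurrent callers."""
    if key in _INIT_DONE:        # fast path, no lock taken
        return
    with _INIT_LOCK:
        if key in _INIT_DONE:    # re-check: another thread may have finished first
            return
        init_fn(key)
        _INIT_DONE.add(key)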
@@ -253,32 +360,27 @@ def __init__(self, config: NebulaGraphDBConfig):
         self._client_key, self._client = self._get_or_create_shared_client(config)
         self._owns_client = True
 
-        # auto-create graph type / graph / index if needed
-        if getattr(config, "auto_create", False):
-            self._ensure_database_exists()
-
-            # Create only if not exists
-            self.create_index(dimensions=config.embedding_dimension)
         logger.info("Connected to NebulaGraph successfully.")
 
     @timed
     def execute_query(self, gql: str, timeout: float = 60.0, auto_set_db: bool = True):
-        try:
+        def _wrap_use_db(q: str) -> str:
             if auto_set_db and self.db_name:
-                gql = f"""USE `{self.db_name}`
-                {gql}"""
-            return self._client.execute(gql, timeout=timeout)
+                return f"USE `{self.db_name}`\n{q}"
+            return q
+
+        try:
+            return self._client.execute(_wrap_use_db(gql), timeout=timeout)
+
         except Exception as e:
             emsg = str(e)
-            if "Session not found" in emsg or "Connection not established" in emsg:
-                logger.warning(f"[execute_query] {e!s}, retry once...")
+            if any(k.lower() in emsg.lower() for k in _TRANSIENT_ERR_KEYS):
+                logger.warning(f"[execute_query] {e!s} → refreshing session pool and retry once...")
                 try:
-                    if auto_set_db and self.db_name:
-                        gql = f"""USE `{self.db_name}`
-                        {gql}"""
-                    return self._client.execute(gql, timeout=timeout)
+                    self._refresh_client()
+                    return self._client.execute(_wrap_use_db(gql), timeout=timeout)
                 except Exception:
-                    logger.exception("[execute_query] retry failed")
+                    logger.exception("[execute_query] retry after refresh failed")
                     raise
             raise
 
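Note: execute_query now prepends the `USE` statement through a small closure and, when the error message matches one of the transient keys, refreshes the shared client and retries exactly once. The shape of that retry, reduced to a standalone sketch (the run/refresh/is_transient callables are placeholders, not APIs from this repository):

def call_with_one_retry(run, refresh, is_transient):
    """Call run(); on a transient failure, refresh the connection and retry once."""
    try:
        return run()
    except Exception as exc:
        if not is_transient(exc):
            raise
        refresh()        # e.g. rebuild the session pool
        return run()     # a second failure propagates to the caller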
@@ -931,7 +1033,7 @@ def search_by_embedding(
                 id_val = values[0].as_string()
                 score_val = values[1].as_double()
                 score_val = (score_val + 1) / 2  # align to neo4j, Normalized Cosine Score
-                if threshold is None or score_val <= threshold:
+                if threshold is None or score_val >= threshold:
                     output.append({"id": id_val, "score": score_val})
             return output
         except Exception as e:
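Note on the comparison fix: the raw cosine similarity in [-1, 1] is remapped to [0, 1] via (score + 1) / 2, so higher scores mean closer matches and the threshold should act as a floor. For example, a raw similarity of 0.8 maps to 0.9; with threshold=0.85 the old `<=` check would have dropped that strong match while keeping weak ones, whereas `>=` keeps it and filters out, say, a raw 0.2 (mapped to 0.6).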
@@ -1261,6 +1363,7 @@ def get_structure_optimization_candidates(
             where_clause += f' AND n.user_name = "{self.config.user_name}"'
 
         return_fields = self._build_return_fields(include_embedding)
+        return_fields += f", n.{self.dim_field} AS {self.dim_field}"
 
         query = f"""
         MATCH (n@Memory)
@@ -1272,11 +1375,16 @@
         """
 
         candidates = []
+        node_ids = set()
         try:
            results = self.execute_query(query)
            for row in results:
                props = {k: v.value for k, v in row.items()}
-                candidates.append(self._parse_node(props))
+                node = self._parse_node(props)
+                node_id = node["id"]
+                if node_id not in node_ids:
+                    candidates.append(node)
+                    node_ids.add(node_id)
        except Exception as e:
            logger.error(f"Failed : {e}, traceback: {traceback.format_exc()}")
        return candidates
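Note: the duplicate filter above keeps the first occurrence of each node id while preserving result order. The same seen-set idiom as a self-contained sketch (function name and row shape are illustrative only):

def dedup_by_id(rows: list[dict]) -> list[dict]:
    """Keep the first row per 'id', preserving the original order."""
    seen: set[str] = set()
    unique: list[dict] = []
    for row in rows:
        if row["id"] not in seen:
            seen.add(row["id"])
            unique.append(row)
    return unique

# dedup_by_id([{"id": "a"}, {"id": "a"}, {"id": "b"}]) -> [{"id": "a"}, {"id": "b"}]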

src/memos/memories/textual/tree_text_memory/organize/handler.py

Lines changed: 2 additions & 2 deletions
@@ -52,12 +52,12 @@ def detect(self, memory, top_k: int = 5, scope=None):
             ]
             result = self.llm.generate(prompt).strip()
             if result == "contradictory":
-                logger.warning(
+                logger.info(
                     f'detected "{memory.memory}" <==CONFLICT==> "{embedding_candidate.memory}"'
                 )
                 detected_relationships.append([memory, embedding_candidate, "contradictory"])
             elif result == "redundant":
-                logger.warning(
+                logger.info(
                     f'detected "{memory.memory}" <==REDUNDANT==> "{embedding_candidate.memory}"'
                 )
                 detected_relationships.append([memory, embedding_candidate, "redundant"])

src/memos/memories/textual/tree_text_memory/organize/manager.py

Lines changed: 2 additions & 2 deletions
@@ -58,7 +58,7 @@ def add(self, memories: list[TextualMemoryItem]) -> list[str]:
 
         with ContextThreadPoolExecutor(max_workers=8) as executor:
             futures = {executor.submit(self._process_memory, m): m for m in memories}
-            for future in as_completed(futures):
+            for future in as_completed(futures, timeout=60):
                 try:
                     ids = future.result()
                     added_ids.extend(ids)
@@ -88,7 +88,7 @@ def replace_working_memory(self, memories: list[TextualMemoryItem]) -> None:
                 executor.submit(self._add_memory_to_db, memory, "WorkingMemory")
                 for memory in working_memory_top_k
             ]
-            for future in as_completed(futures):
+            for future in as_completed(futures, timeout=60):
                 try:
                     future.result()
                 except Exception as e:
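Note: passing timeout=60 to concurrent.futures.as_completed bounds how long the result loop can wait in total; if any future is still pending 60 seconds after the as_completed call, iteration raises concurrent.futures.TimeoutError instead of blocking indefinitely. A small self-contained illustration (the slow_task worker is made up for this example):

import concurrent.futures
import time

def slow_task(seconds: float) -> float:
    time.sleep(seconds)
    return seconds

with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
    futures = [executor.submit(slow_task, s) for s in (0.1, 5.0)]
    try:
        for future in concurrent.futures.as_completed(futures, timeout=1):
            print("done:", future.result())
    except concurrent.futures.TimeoutError:
        # raised because the 5-second task is still unfinished after the 1-second budget
        print("timed out waiting for remaining futures")
    # the executor still waits for slow_task(5.0) on shutdown when the with-block exits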

src/memos/memories/textual/tree_text_memory/organize/relation_reason_detector.py

Lines changed: 1 addition & 2 deletions
@@ -46,7 +46,7 @@ def process_node(self, node: GraphDBNode, exclude_ids: list[str], top_k: int = 5
             "sequence_links": [],
             "aggregate_nodes": [],
         }
-
+        """
         nearest = self.graph_store.get_neighbors_by_tag(
             tags=node.metadata.tags,
             exclude_ids=exclude_ids,
@@ -55,7 +55,6 @@ def process_node(self, node: GraphDBNode, exclude_ids: list[str], top_k: int = 5
         )
         nearest = [GraphDBNode(**cand_data) for cand_data in nearest]
 
-        """
         # 1) Pairwise relations (including CAUSE/CONDITION/CONFLICT)
         pairwise = self._detect_pairwise_causal_condition_relations(node, nearest)
         results["relations"].extend(pairwise["relations"])
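Note: the only change in this file is where the docstring's closing triple quote sits; it now closes right after the documented result shape, so the get_neighbors_by_tag lookup that follows is executed as code rather than being absorbed into the string literal. A toy illustration of the failure mode (not code from the repository):

def before_fix():
    """The closing quotes used to sit below the lookup, so the line
    nearest = graph_store.get_neighbors_by_tag(...)  # never executed
    was just text inside this docstring.
    """

def after_fix(graph_store, node, exclude_ids):
    """The docstring now closes here, so the lookup below actually runs."""
    return graph_store.get_neighbors_by_tag(
        tags=node.metadata.tags,
        exclude_ids=exclude_ids,
    )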
