 logger = get_logger(__name__)
 
 
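+# Error-message substrings treated as transient connection failures.
+# execute_query() matches them case-insensitively to decide whether to
+# refresh the shared client and retry a statement once.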
+_TRANSIENT_ERR_KEYS = (
+    "Session not found",
+    "Connection not established",
+    "timeout",
+    "deadline exceeded",
+    "Broken pipe",
+    "EOFError",
+    "socket closed",
+    "connection reset",
+    "connection refused",
+)
+
+
 @timed
 def _normalize(vec: list[float]) -> list[float]:
     v = np.asarray(vec, dtype=np.float32)
@@ -99,6 +112,7 @@ class NebulaGraphDB(BaseGraphDB):
     _CLIENT_CACHE: ClassVar[dict[str, "NebulaClient"]] = {}
     _CLIENT_REFCOUNT: ClassVar[dict[str, int]] = {}
     _CLIENT_LOCK: ClassVar[Lock] = Lock()
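+    # Client keys whose one-time schema/index bootstrap has already completed.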
+    _CLIENT_INIT_DONE: ClassVar[set[str]] = set()
 
     @staticmethod
     def _get_hosts_from_cfg(cfg: NebulaGraphDBConfig) -> list[str]:
@@ -115,13 +129,53 @@ def _make_client_key(cfg: NebulaGraphDBConfig) -> str:
115129 "nebula-sync" ,
116130 "," .join (hosts ),
117131 str (getattr (cfg , "user" , "" )),
118- str (getattr (cfg , "password" , "" )),
119132 str (getattr (cfg , "use_multi_db" , False )),
133+ str (getattr (cfg , "space" , "" )),
120134 ]
121135 )
122136
     @classmethod
-    def _get_or_create_shared_client(cls, cfg: NebulaGraphDBConfig) -> (tuple)[str, "NebulaClient"]:
+    def _bootstrap_admin(cls, cfg: NebulaGraphDBConfig, client: "NebulaClient") -> "NebulaGraphDB":
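+        """Create a minimal NebulaGraphDB handle without running __init__.
+
+        Bound to an existing shared client (which it does not own) so the
+        one-time schema/index bootstrap can run before any full instance exists.
+        """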
+        tmp = object.__new__(NebulaGraphDB)
+        tmp.config = cfg
+        tmp.db_name = cfg.space
+        tmp.user_name = getattr(cfg, "user_name", None)
+        tmp.embedding_dimension = getattr(cfg, "embedding_dimension", 3072)
+        tmp.default_memory_dimension = 3072
+        tmp.common_fields = {
+            "id",
+            "memory",
+            "user_name",
+            "user_id",
+            "session_id",
+            "status",
+            "key",
+            "confidence",
+            "tags",
+            "created_at",
+            "updated_at",
+            "memory_type",
+            "sources",
+            "source",
+            "node_type",
+            "visibility",
+            "usage",
+            "background",
+        }
+        tmp.base_fields = set(tmp.common_fields) - {"usage"}
+        tmp.heavy_fields = {"usage"}
+        tmp.dim_field = (
+            f"embedding_{tmp.embedding_dimension}"
+            if str(tmp.embedding_dimension) != str(tmp.default_memory_dimension)
+            else "embedding"
+        )
+        tmp.system_db_name = "system" if getattr(cfg, "use_multi_db", False) else cfg.space
+        tmp._client = client
+        tmp._owns_client = False
+        return tmp
+
+    @classmethod
+    def _get_or_create_shared_client(cls, cfg: NebulaGraphDBConfig) -> tuple[str, "NebulaClient"]:
         from nebulagraph_python import (
             ConnectionConfig,
             NebulaClient,
@@ -159,7 +213,60 @@ def _get_or_create_shared_client(cls, cfg: NebulaGraphDBConfig) -> (tuple)[str,
                 logger.info(f"[NebulaGraphDBSync] Created shared NebulaClient key={key}")
 
             cls._CLIENT_REFCOUNT[key] = cls._CLIENT_REFCOUNT.get(key, 0) + 1
-        return key, client
+
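+        # One-time schema bootstrap per client key: checked outside the lock, then
+        # re-checked under it, so only the first instance pays the setup cost.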
+        if getattr(cfg, "auto_create", False) and key not in cls._CLIENT_INIT_DONE:
+            with cls._CLIENT_LOCK:
+                if key not in cls._CLIENT_INIT_DONE:
+                    admin = cls._bootstrap_admin(cfg, client)
+                    try:
+                        admin._ensure_database_exists()
+                        admin._create_basic_property_indexes()
+                        admin._create_vector_index(
+                            label="Memory",
+                            vector_property=admin.dim_field,
+                            dimensions=int(
+                                admin.embedding_dimension or admin.default_memory_dimension
+                            ),
+                            index_name="memory_vector_index",
+                        )
+                        cls._CLIENT_INIT_DONE.add(key)
+                        logger.info("[NebulaGraphDBSync] One-time init done")
+                    except Exception:
+                        logger.exception("[NebulaGraphDBSync] One-time init failed")
+
+        return key, client
+
+    def _refresh_client(self):
+        """Close the stale shared NebulaClient for this key and rebind to a fresh one."""
+        old_key = getattr(self, "_client_key", None)
+        if not old_key:
+            return
+
+        cls = self.__class__
+        with cls._CLIENT_LOCK:
+            try:
+                if old_key in cls._CLIENT_CACHE:
+                    try:
+                        cls._CLIENT_CACHE[old_key].close()
+                    except Exception as e:
+                        logger.warning(f"[refresh_client] close old client error: {e}")
+                    finally:
+                        cls._CLIENT_CACHE.pop(old_key, None)
+            finally:
+                cls._CLIENT_REFCOUNT[old_key] = 0
+
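+        # Recreate (or re-fetch) a shared client for the same config and rebind this instance.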
+        new_key, new_client = cls._get_or_create_shared_client(self.config)
+        self._client_key = new_key
+        self._client = new_client
+        logger.info(f"[NebulaGraphDBSync] client refreshed: {old_key} -> {new_key}")
 
     @classmethod
     def _release_shared_client(cls, key: str):
@@ -253,32 +360,27 @@ def __init__(self, config: NebulaGraphDBConfig):
         self._client_key, self._client = self._get_or_create_shared_client(config)
         self._owns_client = True
 
-        # auto-create graph type / graph / index if needed
-        if getattr(config, "auto_create", False):
-            self._ensure_database_exists()
-
-            # Create only if not exists
-            self.create_index(dimensions=config.embedding_dimension)
         logger.info("Connected to NebulaGraph successfully.")
 
     @timed
     def execute_query(self, gql: str, timeout: float = 60.0, auto_set_db: bool = True):
-        try:
+        def _wrap_use_db(q: str) -> str:
             if auto_set_db and self.db_name:
-                gql = f"""USE `{self.db_name}`
-                {gql}"""
-            return self._client.execute(gql, timeout=timeout)
+                return f"USE `{self.db_name}`\n{q}"
+            return q
+
+        try:
+            return self._client.execute(_wrap_use_db(gql), timeout=timeout)
+
         except Exception as e:
             emsg = str(e)
-            if "Session not found" in emsg or "Connection not established" in emsg:
-                logger.warning(f"[execute_query] {e!s}, retry once...")
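+            # Transient connection error: refresh the shared client and retry the
+            # statement exactly once; any other error propagates immediately.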
+            if any(k.lower() in emsg.lower() for k in _TRANSIENT_ERR_KEYS):
+                logger.warning(f"[execute_query] {e!s} → refreshing client and retrying once...")
                 try:
-                    if auto_set_db and self.db_name:
-                        gql = f"""USE `{self.db_name}`
-                        {gql}"""
-                    return self._client.execute(gql, timeout=timeout)
+                    self._refresh_client()
+                    return self._client.execute(_wrap_use_db(gql), timeout=timeout)
                 except Exception:
-                    logger.exception("[execute_query] retry failed")
+                    logger.exception("[execute_query] retry after refresh failed")
                     raise
             raise
 
@@ -931,7 +1033,7 @@ def search_by_embedding(
                 id_val = values[0].as_string()
                 score_val = values[1].as_double()
                 score_val = (score_val + 1) / 2  # align to neo4j, Normalized Cosine Score
-                if threshold is None or score_val <= threshold:
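+                # Higher normalized score means more similar, so keep hits at or above the threshold.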
+                if threshold is None or score_val >= threshold:
                     output.append({"id": id_val, "score": score_val})
             return output
         except Exception as e:
@@ -1261,6 +1363,7 @@ def get_structure_optimization_candidates(
             where_clause += f' AND n.user_name = "{self.config.user_name}"'
 
         return_fields = self._build_return_fields(include_embedding)
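+        # Also project the embedding property explicitly, regardless of include_embedding.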
+        return_fields += f", n.{self.dim_field} AS {self.dim_field}"
 
         query = f"""
         MATCH (n@Memory)
@@ -1272,11 +1375,16 @@ def get_structure_optimization_candidates(
12721375 """
12731376
12741377 candidates = []
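+        # Deduplicate by node id so each candidate appears only once.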
+        node_ids = set()
         try:
             results = self.execute_query(query)
             for row in results:
                 props = {k: v.value for k, v in row.items()}
-                candidates.append(self._parse_node(props))
+                node = self._parse_node(props)
+                node_id = node["id"]
+                if node_id not in node_ids:
+                    candidates.append(node)
+                    node_ids.add(node_id)
         except Exception as e:
             logger.error(f"Failed : {e}, traceback: {traceback.format_exc()}")
         return candidates