33
44from contextlib import suppress
55from datetime import datetime
6- from queue import Empty , Queue
76from threading import Lock
87from typing import TYPE_CHECKING , Any , ClassVar , Literal
98
1817
1918
2019if TYPE_CHECKING :
21- from nebulagraph_python .client .pool import NebulaPool
20+ from nebulagraph_python import (
21+ NebulaClient ,
22+ )
2223
2324
2425logger = get_logger (__name__ )
@@ -88,141 +89,6 @@ def _normalize_datetime(val):
8889 return str (val )
8990
9091
91- class SessionPoolError (Exception ):
92- pass
93-
94-
95- class SessionPool :
96- @require_python_package (
97- import_name = "nebulagraph_python" ,
98- install_command = "pip install ... @Tianxing" ,
99- install_link = "....." ,
100- )
101- def __init__ (
102- self ,
103- hosts : list [str ],
104- user : str ,
105- password : str ,
106- minsize : int = 1 ,
107- maxsize : int = 10000 ,
108- ):
109- self .hosts = hosts
110- self .user = user
111- self .password = password
112- self .minsize = minsize
113- self .maxsize = maxsize
114- self .pool = Queue (maxsize )
115- self .lock = Lock ()
116-
117- self .clients = []
118-
119- for _ in range (minsize ):
120- self ._create_and_add_client ()
121-
122- @timed
123- def _create_and_add_client (self ):
124- from nebulagraph_python import NebulaClient
125-
126- client = NebulaClient (self .hosts , self .user , self .password )
127- self .pool .put (client )
128- self .clients .append (client )
129-
130- @timed
131- def get_client (self , timeout : float = 5.0 ):
132- try :
133- return self .pool .get (timeout = timeout )
134- except Empty :
135- with self .lock :
136- if len (self .clients ) < self .maxsize :
137- from nebulagraph_python import NebulaClient
138-
139- client = NebulaClient (self .hosts , self .user , self .password )
140- self .clients .append (client )
141- return client
142- raise RuntimeError ("NebulaClientPool exhausted" ) from None
143-
144- @timed
145- def return_client (self , client ):
146- try :
147- client .execute ("YIELD 1" )
148- self .pool .put (client )
149- except Exception :
150- if settings .DEBUG :
151- logger .info ("[Pool] Client dead, replacing..." )
152-
153- self .replace_client (client )
154-
155- @timed
156- def close (self ):
157- for client in self .clients :
158- with suppress (Exception ):
159- client .close ()
160- self .clients .clear ()
161-
162- @timed
163- def get (self ):
164- """
165- Context manager: with pool.get() as client:
166- """
167-
168- class _ClientContext :
169- def __init__ (self , outer ):
170- self .outer = outer
171- self .client = None
172-
173- def __enter__ (self ):
174- self .client = self .outer .get_client ()
175- return self .client
176-
177- def __exit__ (self , exc_type , exc_val , exc_tb ):
178- if self .client :
179- self .outer .return_client (self .client )
180-
181- return _ClientContext (self )
182-
183- @timed
184- def reset_pool (self ):
185- """⚠️ Emergency reset: Close all clients and clear the pool."""
186- logger .warning ("[Pool] Resetting all clients. Existing sessions will be lost." )
187- with self .lock :
188- for client in self .clients :
189- try :
190- client .close ()
191- except Exception :
192- logger .error ("Fail to close!!!" )
193- self .clients .clear ()
194- while not self .pool .empty ():
195- try :
196- self .pool .get_nowait ()
197- except Empty :
198- break
199- for _ in range (self .minsize ):
200- self ._create_and_add_client ()
201- logger .info ("[Pool] Pool has been reset successfully." )
202-
203- @timed
204- def replace_client (self , client ):
205- try :
206- client .close ()
207- except Exception :
208- logger .error ("Fail to close client" )
209-
210- if client in self .clients :
211- self .clients .remove (client )
212-
213- from nebulagraph_python import NebulaClient
214-
215- new_client = NebulaClient (self .hosts , self .user , self .password )
216- self .clients .append (new_client )
217-
218- self .pool .put (new_client )
219-
220- if settings .DEBUG :
221- logger .info (f"[Pool] Replaced dead client with a new one. { new_client } " )
222-
223- return new_client
224-
225-
22692class NebulaGraphDB (BaseGraphDB ):
22793 """
22894 NebulaGraph-based implementation of a graph memory store.
@@ -231,94 +97,102 @@ class NebulaGraphDB(BaseGraphDB):
23197 # ====== shared pool cache & refcount ======
23298 # These are process-local; in a multi-process model each process will
23399 # have its own cache.
234- _POOL_CACHE : ClassVar [dict [str , "NebulaPool " ]] = {}
235- _POOL_REFCOUNT : ClassVar [dict [str , int ]] = {}
236- _POOL_LOCK : ClassVar [Lock ] = Lock ()
100+ _CLIENT_CACHE : ClassVar [dict [str , "NebulaClient " ]] = {}
101+ _CLIENT_REFCOUNT : ClassVar [dict [str , int ]] = {}
102+ _CLIENT_LOCK : ClassVar [Lock ] = Lock ()
237103
238104 @staticmethod
239- def _make_pool_key (cfg : NebulaGraphDBConfig ) -> str :
240- """
241- Build a cache key that captures all connection-affecting options.
242- Keep this key stable and include fields that change the underlying pool behavior.
243- """
244- # NOTE: Do not include tenant-like or query-scope-only fields here.
245- # Only include things that affect the actual TCP/auth/session pool.
105+ def _get_hosts_from_cfg (cfg : NebulaGraphDBConfig ) -> list [str ]:
106+ hosts = getattr (cfg , "uri" , None ) or getattr (cfg , "hosts" , None )
107+ if isinstance (hosts , str ):
108+ return [hosts ]
109+ return list (hosts or [])
110+
111+ @staticmethod
112+ def _make_client_key (cfg : NebulaGraphDBConfig ) -> str :
113+ hosts = NebulaGraphDB ._get_hosts_from_cfg (cfg )
246114 return "|" .join (
247115 [
248- "nebula" ,
249- str ( getattr ( cfg , "uri" , "" ) ),
116+ "nebula-sync " ,
117+ "," . join ( hosts ),
250118 str (getattr (cfg , "user" , "" )),
251119 str (getattr (cfg , "password" , "" )),
252- # pool sizing / tls / timeouts if you have them in config:
253- str (getattr (cfg , "max_client" , 1000 )),
254- # multi-db mode can impact how we use sessions; keep it to be safe
255120 str (getattr (cfg , "use_multi_db" , False )),
256121 ]
257122 )
258123
259124 @classmethod
260- def _get_or_create_shared_pool (cls , cfg : NebulaGraphDBConfig ):
261- """
262- Get a shared NebulaPool from cache or create one if missing.
263- Thread-safe with a lock; maintains a simple refcount.
264- """
265- key = cls ._make_pool_key (cfg )
266-
267- with cls ._POOL_LOCK :
268- pool = cls ._POOL_CACHE .get (key )
269- if pool is None :
270- # Create a new pool and put into cache
271- pool = SessionPool (
272- hosts = cfg .get ("uri" ),
273- user = cfg .get ("user" ),
274- password = cfg .get ("password" ),
275- minsize = 1 ,
276- maxsize = cfg .get ("max_client" , 1000 ),
125+ def _get_or_create_shared_client (cls , cfg : NebulaGraphDBConfig ) -> (tuple )[str , "NebulaClient" ]:
126+ from nebulagraph_python import (
127+ ConnectionConfig ,
128+ NebulaClient ,
129+ SessionConfig ,
130+ SessionPoolConfig ,
131+ )
132+
133+ key = cls ._make_client_key (cfg )
134+ with cls ._CLIENT_LOCK :
135+ client = cls ._CLIENT_CACHE .get (key )
136+ if client is None :
137+ # Connection setting
138+ conn_conf : ConnectionConfig | None = getattr (cfg , "conn_config" , None )
139+ if conn_conf is None :
140+ conn_conf = ConnectionConfig .from_defults (
141+ cls ._get_hosts_from_cfg (cfg ),
142+ getattr (cfg , "ssl_param" , None ),
143+ )
144+
145+ sess_conf = SessionConfig (graph = getattr (cfg , "space" , None ))
146+
147+ pool_conf = SessionPoolConfig (size = int (getattr (cfg , "max_client" , 1000 )))
148+
149+ client = NebulaClient (
150+ hosts = conn_conf .hosts ,
151+ username = cfg .user ,
152+ password = cfg .password ,
153+ conn_config = conn_conf ,
154+ session_config = sess_conf ,
155+ session_pool_config = pool_conf ,
277156 )
278- cls ._POOL_CACHE [key ] = pool
279- cls ._POOL_REFCOUNT [key ] = 0
280- logger .info (f"[NebulaGraphDB ] Created new shared NebulaPool for key={ key } " )
157+ cls ._CLIENT_CACHE [key ] = client
158+ cls ._CLIENT_REFCOUNT [key ] = 0
159+ logger .info (f"[NebulaGraphDBSync ] Created shared NebulaClient key={ key } " )
281160
282- # Increase refcount for the caller
283- cls ._POOL_REFCOUNT [key ] = cls ._POOL_REFCOUNT .get (key , 0 ) + 1
284- return key , pool
161+ cls ._CLIENT_REFCOUNT [key ] = cls ._CLIENT_REFCOUNT .get (key , 0 ) + 1
162+ return key , client
285163
286164 @classmethod
287- def _release_shared_pool (cls , key : str ):
288- """
289- Decrease refcount for the given pool key; only close when refcount hits zero.
290- """
291- with cls ._POOL_LOCK :
292- if key not in cls ._POOL_CACHE :
165+ def _release_shared_client (cls , key : str ):
166+ with cls ._CLIENT_LOCK :
167+ if key not in cls ._CLIENT_CACHE :
293168 return
294- cls ._POOL_REFCOUNT [key ] = max (0 , cls ._POOL_REFCOUNT .get (key , 0 ) - 1 )
295- if cls ._POOL_REFCOUNT [key ] == 0 :
169+ cls ._CLIENT_REFCOUNT [key ] = max (0 , cls ._CLIENT_REFCOUNT .get (key , 0 ) - 1 )
170+ if cls ._CLIENT_REFCOUNT [key ] == 0 :
296171 try :
297- cls ._POOL_CACHE [key ].close ()
172+ cls ._CLIENT_CACHE [key ].close ()
298173 except Exception as e :
299- logger .warning (f"[NebulaGraphDB ] Error closing shared pool : { e } " )
174+ logger .warning (f"[NebulaGraphDBSync ] Error closing client : { e } " )
300175 finally :
301- cls ._POOL_CACHE .pop (key , None )
302- cls ._POOL_REFCOUNT .pop (key , None )
303- logger .info (f"[NebulaGraphDB ] Closed and removed shared pool key={ key } " )
176+ cls ._CLIENT_CACHE .pop (key , None )
177+ cls ._CLIENT_REFCOUNT .pop (key , None )
178+ logger .info (f"[NebulaGraphDBSync ] Closed & removed client key={ key } " )
304179
305180 @classmethod
306- def close_all_shared_pools (cls ):
307- """Force close all cached pools. Call this on graceful shutdown."""
308- with cls ._POOL_LOCK :
309- for key , pool in list (cls ._POOL_CACHE .items ()):
181+ def close_all_shared_clients (cls ):
182+ with cls ._CLIENT_LOCK :
183+ for key , client in list (cls ._CLIENT_CACHE .items ()):
310184 try :
311- pool .close ()
185+ client .close ()
312186 except Exception as e :
313- logger .warning (f"[NebulaGraphDB ] Error closing pool key= { key } : { e } " )
187+ logger .warning (f"[NebulaGraphDBSync ] Error closing client { key } : { e } " )
314188 finally :
315- logger .info (f"[NebulaGraphDB ] Closed pool key={ key } " )
316- cls ._POOL_CACHE .clear ()
317- cls ._POOL_REFCOUNT .clear ()
189+ logger .info (f"[NebulaGraphDBSync ] Closed client key={ key } " )
190+ cls ._CLIENT_CACHE .clear ()
191+ cls ._CLIENT_REFCOUNT .clear ()
318192
319193 @require_python_package (
320194 import_name = "nebulagraph_python" ,
321- install_command = "pip install ... @Tianxing " ,
195+ install_command = "pip install nebulagraph-python>=5.1.1 " ,
322196 install_link = "....." ,
323197 )
324198 def __init__ (self , config : NebulaGraphDBConfig ):
@@ -376,34 +250,35 @@ def __init__(self, config: NebulaGraphDBConfig):
376250
377251 # ---- NEW: pool acquisition strategy
378252 # Get or create a shared pool from the class-level cache
379- self ._pool_key , self .pool = self ._get_or_create_shared_pool (config )
380- self ._owns_pool = True # We manage refcount for this instance
253+ self ._client_key , self ._client = self ._get_or_create_shared_client (config )
254+ self ._owns_client = True
381255
382256 # auto-create graph type / graph / index if needed
383- if config . auto_create :
257+ if getattr ( config , " auto_create" , False ) :
384258 self ._ensure_database_exists ()
385259
386260 self .execute_query (f"SESSION SET GRAPH `{ self .db_name } `" )
387261
388262 # Create only if not exists
389263 self .create_index (dimensions = config .embedding_dimension )
390-
391264 logger .info ("Connected to NebulaGraph successfully." )
392265
393266 @timed
394267 def execute_query (self , gql : str , timeout : float = 10.0 , auto_set_db : bool = True ):
395- with self .pool .get () as client :
396- try :
397- if auto_set_db and self .db_name :
398- client .execute (f"SESSION SET GRAPH `{ self .db_name } `" )
399- return client .execute (gql , timeout = timeout )
400-
401- except Exception as e :
402- if "Session not found" in str (e ) or "Connection not established" in str (e ):
403- logger .warning (f"[execute_query] { e !s} , replacing client..." )
404- self .pool .replace_client (client )
405- return self .execute_query (gql , timeout , auto_set_db )
406- raise
268+ try :
269+ if auto_set_db and self .db_name :
270+ self ._client .execute (f"SESSION SET GRAPH `{ self .db_name } `" )
271+ return self ._client .execute (gql , timeout = timeout )
272+ except Exception as e :
273+ emsg = str (e )
274+ if "Session not found" in emsg or "Connection not established" in emsg :
275+ logger .warning (f"[execute_query] { e !s} , retry once..." )
276+ try :
277+ return self ._client .execute (gql , timeout = timeout )
278+ except Exception :
279+ logger .exception ("[execute_query] retry failed" )
280+ raise
281+ raise
407282
408283 @timed
409284 def close (self ):
@@ -414,13 +289,13 @@ def close(self):
414289 - If pool was acquired via shared cache, decrement refcount and close
415290 when the last owner releases it.
416291 """
417- if not self ._owns_pool :
418- logger .debug ("[NebulaGraphDB ] close() skipped (injected pool )." )
292+ if not self ._owns_client :
293+ logger .debug ("[NebulaGraphDBSync ] close() skipped (injected client )." )
419294 return
420- if self ._pool_key :
421- self ._release_shared_pool (self ._pool_key )
422- self ._pool_key = None
423- self .pool = None
295+ if self ._client_key :
296+ self ._release_shared_client (self ._client_key )
297+ self ._client_key = None
298+ self ._client = None
424299
425300 # NOTE: __del__ is best-effort; do not rely on GC order.
426301 def __del__ (self ):
0 commit comments