Skip to content

Commit 02b0983

Browse files
authored
feat: update nebula to nebula 5.1.1 (#311)
1 parent 065a378 commit 02b0983

File tree

1 file changed

+95
-220
lines changed

1 file changed

+95
-220
lines changed

src/memos/graph_dbs/nebular.py

Lines changed: 95 additions & 220 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
from contextlib import suppress
55
from datetime import datetime
6-
from queue import Empty, Queue
76
from threading import Lock
87
from typing import TYPE_CHECKING, Any, ClassVar, Literal
98

@@ -18,7 +17,9 @@
1817

1918

2019
if TYPE_CHECKING:
21-
from nebulagraph_python.client.pool import NebulaPool
20+
from nebulagraph_python import (
21+
NebulaClient,
22+
)
2223

2324

2425
logger = get_logger(__name__)
@@ -88,141 +89,6 @@ def _normalize_datetime(val):
8889
return str(val)
8990

9091

91-
class SessionPoolError(Exception):
92-
pass
93-
94-
95-
class SessionPool:
96-
@require_python_package(
97-
import_name="nebulagraph_python",
98-
install_command="pip install ... @Tianxing",
99-
install_link=".....",
100-
)
101-
def __init__(
102-
self,
103-
hosts: list[str],
104-
user: str,
105-
password: str,
106-
minsize: int = 1,
107-
maxsize: int = 10000,
108-
):
109-
self.hosts = hosts
110-
self.user = user
111-
self.password = password
112-
self.minsize = minsize
113-
self.maxsize = maxsize
114-
self.pool = Queue(maxsize)
115-
self.lock = Lock()
116-
117-
self.clients = []
118-
119-
for _ in range(minsize):
120-
self._create_and_add_client()
121-
122-
@timed
123-
def _create_and_add_client(self):
124-
from nebulagraph_python import NebulaClient
125-
126-
client = NebulaClient(self.hosts, self.user, self.password)
127-
self.pool.put(client)
128-
self.clients.append(client)
129-
130-
@timed
131-
def get_client(self, timeout: float = 5.0):
132-
try:
133-
return self.pool.get(timeout=timeout)
134-
except Empty:
135-
with self.lock:
136-
if len(self.clients) < self.maxsize:
137-
from nebulagraph_python import NebulaClient
138-
139-
client = NebulaClient(self.hosts, self.user, self.password)
140-
self.clients.append(client)
141-
return client
142-
raise RuntimeError("NebulaClientPool exhausted") from None
143-
144-
@timed
145-
def return_client(self, client):
146-
try:
147-
client.execute("YIELD 1")
148-
self.pool.put(client)
149-
except Exception:
150-
if settings.DEBUG:
151-
logger.info("[Pool] Client dead, replacing...")
152-
153-
self.replace_client(client)
154-
155-
@timed
156-
def close(self):
157-
for client in self.clients:
158-
with suppress(Exception):
159-
client.close()
160-
self.clients.clear()
161-
162-
@timed
163-
def get(self):
164-
"""
165-
Context manager: with pool.get() as client:
166-
"""
167-
168-
class _ClientContext:
169-
def __init__(self, outer):
170-
self.outer = outer
171-
self.client = None
172-
173-
def __enter__(self):
174-
self.client = self.outer.get_client()
175-
return self.client
176-
177-
def __exit__(self, exc_type, exc_val, exc_tb):
178-
if self.client:
179-
self.outer.return_client(self.client)
180-
181-
return _ClientContext(self)
182-
183-
@timed
184-
def reset_pool(self):
185-
"""⚠️ Emergency reset: Close all clients and clear the pool."""
186-
logger.warning("[Pool] Resetting all clients. Existing sessions will be lost.")
187-
with self.lock:
188-
for client in self.clients:
189-
try:
190-
client.close()
191-
except Exception:
192-
logger.error("Fail to close!!!")
193-
self.clients.clear()
194-
while not self.pool.empty():
195-
try:
196-
self.pool.get_nowait()
197-
except Empty:
198-
break
199-
for _ in range(self.minsize):
200-
self._create_and_add_client()
201-
logger.info("[Pool] Pool has been reset successfully.")
202-
203-
@timed
204-
def replace_client(self, client):
205-
try:
206-
client.close()
207-
except Exception:
208-
logger.error("Fail to close client")
209-
210-
if client in self.clients:
211-
self.clients.remove(client)
212-
213-
from nebulagraph_python import NebulaClient
214-
215-
new_client = NebulaClient(self.hosts, self.user, self.password)
216-
self.clients.append(new_client)
217-
218-
self.pool.put(new_client)
219-
220-
if settings.DEBUG:
221-
logger.info(f"[Pool] Replaced dead client with a new one. {new_client}")
222-
223-
return new_client
224-
225-
22692
class NebulaGraphDB(BaseGraphDB):
22793
"""
22894
NebulaGraph-based implementation of a graph memory store.
@@ -231,94 +97,102 @@ class NebulaGraphDB(BaseGraphDB):
23197
# ====== shared client cache & refcount ======
23298
# These are process-local; in a multi-process model each process will
23399
# have its own cache.
234-
_POOL_CACHE: ClassVar[dict[str, "NebulaPool"]] = {}
235-
_POOL_REFCOUNT: ClassVar[dict[str, int]] = {}
236-
_POOL_LOCK: ClassVar[Lock] = Lock()
100+
_CLIENT_CACHE: ClassVar[dict[str, "NebulaClient"]] = {}
101+
_CLIENT_REFCOUNT: ClassVar[dict[str, int]] = {}
102+
_CLIENT_LOCK: ClassVar[Lock] = Lock()
237103

238104
@staticmethod
239-
def _make_pool_key(cfg: NebulaGraphDBConfig) -> str:
240-
"""
241-
Build a cache key that captures all connection-affecting options.
242-
Keep this key stable and include fields that change the underlying pool behavior.
243-
"""
244-
# NOTE: Do not include tenant-like or query-scope-only fields here.
245-
# Only include things that affect the actual TCP/auth/session pool.
105+
def _get_hosts_from_cfg(cfg: NebulaGraphDBConfig) -> list[str]:
    """Return the configured host addresses as a new list.

    Reads ``cfg.uri`` first and falls back to ``cfg.hosts`` when ``uri`` is
    missing or falsy (``None``, ``""``, empty sequence).

    Args:
        cfg: Configuration object; both attributes are optional.

    Returns:
        ``[value]`` when the configured value is a single string, a list copy
        of the configured iterable otherwise, or ``[]`` when neither
        attribute holds a truthy value.
    """
    hosts = getattr(cfg, "uri", None) or getattr(cfg, "hosts", None)
    if isinstance(hosts, str):
        # A bare string is one address; list("host") would split it into
        # individual characters.
        return [hosts]
    return list(hosts or [])
110+
111+
@staticmethod
112+
def _make_client_key(cfg: NebulaGraphDBConfig) -> str:
    """Build the cache key that identifies the shared client for ``cfg``.

    The key is the ``|``-joined sequence of: the literal ``"nebula-sync"``,
    the comma-joined host list, the user, a SHA-256 hex digest of the
    password, and the ``use_multi_db`` flag. Missing attributes fall back to
    ``""`` (``False`` for ``use_multi_db``).

    The password is hashed instead of embedded verbatim because this key is
    stored in class-level dictionaries and may be interpolated into log
    messages by its consumers; two configs still map to the same key exactly
    when their passwords are equal.

    Args:
        cfg: Configuration object to derive the key from.

    Returns:
        The cache key string.
    """
    import hashlib

    hosts = NebulaGraphDB._get_hosts_from_cfg(cfg)
    password_digest = hashlib.sha256(
        str(getattr(cfg, "password", "")).encode("utf-8")
    ).hexdigest()
    # NOTE(review): `space` and `max_client` are not part of the key, so
    # configs that differ only in those fields share one client — confirm
    # this is intended.
    return "|".join(
        [
            "nebula-sync",
            ",".join(hosts),
            str(getattr(cfg, "user", "")),
            password_digest,
            str(getattr(cfg, "use_multi_db", False)),
        ]
    )
258123

259124
@classmethod
260-
def _get_or_create_shared_pool(cls, cfg: NebulaGraphDBConfig):
261-
"""
262-
Get a shared NebulaPool from cache or create one if missing.
263-
Thread-safe with a lock; maintains a simple refcount.
264-
"""
265-
key = cls._make_pool_key(cfg)
266-
267-
with cls._POOL_LOCK:
268-
pool = cls._POOL_CACHE.get(key)
269-
if pool is None:
270-
# Create a new pool and put into cache
271-
pool = SessionPool(
272-
hosts=cfg.get("uri"),
273-
user=cfg.get("user"),
274-
password=cfg.get("password"),
275-
minsize=1,
276-
maxsize=cfg.get("max_client", 1000),
125+
def _get_or_create_shared_client(cls, cfg: NebulaGraphDBConfig) -> tuple[str, "NebulaClient"]:
    """Return the shared ``NebulaClient`` for ``cfg``, creating it on first use.

    Clients are cached in ``_CLIENT_CACHE`` under the key produced by
    ``_make_client_key``. Every call increments that key's entry in
    ``_CLIENT_REFCOUNT``. Lookup, creation and the refcount update all run
    while holding ``_CLIENT_LOCK``.

    Args:
        cfg: Connection settings. ``user`` and ``password`` are read
            directly; ``conn_config``, ``ssl_param``, ``space`` and
            ``max_client`` are optional and read with ``getattr`` defaults.

    Returns:
        ``(key, client)`` — the cache key and the cached or newly created
        client.
    """
    import hashlib

    from nebulagraph_python import (
        ConnectionConfig,
        NebulaClient,
        SessionConfig,
        SessionPoolConfig,
    )

    key = cls._make_client_key(cfg)
    # The key is derived from connection credentials, so only a short
    # fingerprint of it is ever written to the log.
    key_id = hashlib.sha256(key.encode("utf-8")).hexdigest()[:12]

    with cls._CLIENT_LOCK:
        client = cls._CLIENT_CACHE.get(key)
        if client is None:
            # Connection settings: an explicit cfg.conn_config wins,
            # otherwise build one from the host list and optional SSL params.
            conn_conf: ConnectionConfig | None = getattr(cfg, "conn_config", None)
            if conn_conf is None:
                # NOTE(review): `from_defults` is spelled as in the original
                # call — confirm it matches the installed nebulagraph_python
                # API before "correcting" it.
                conn_conf = ConnectionConfig.from_defults(
                    cls._get_hosts_from_cfg(cfg),
                    getattr(cfg, "ssl_param", None),
                )

            sess_conf = SessionConfig(graph=getattr(cfg, "space", None))
            pool_conf = SessionPoolConfig(size=int(getattr(cfg, "max_client", 1000)))

            client = NebulaClient(
                hosts=conn_conf.hosts,
                username=cfg.user,
                password=cfg.password,
                conn_config=conn_conf,
                session_config=sess_conf,
                session_pool_config=pool_conf,
            )
            cls._CLIENT_CACHE[key] = client
            cls._CLIENT_REFCOUNT[key] = 0
            logger.info(f"[NebulaGraphDBSync] Created shared NebulaClient key_id={key_id}")

        # Count this caller as an owner, whether the client is new or reused.
        cls._CLIENT_REFCOUNT[key] = cls._CLIENT_REFCOUNT.get(key, 0) + 1
        return key, client
285163

286164
@classmethod
287-
def _release_shared_pool(cls, key: str):
288-
"""
289-
Decrease refcount for the given pool key; only close when refcount hits zero.
290-
"""
291-
with cls._POOL_LOCK:
292-
if key not in cls._POOL_CACHE:
165+
def _release_shared_client(cls, key: str):
    """Drop one reference to the shared client stored under ``key``.

    Does nothing when ``key`` is not in ``_CLIENT_CACHE``. Otherwise the
    refcount is decremented (never below zero); when it reaches zero the
    client is closed and both its cache and refcount entries are removed,
    even if ``close()`` raises. Runs while holding ``_CLIENT_LOCK``.

    Args:
        key: Cache key previously returned when the client was acquired.
    """
    import hashlib

    with cls._CLIENT_LOCK:
        if key not in cls._CLIENT_CACHE:
            return

        remaining = max(0, cls._CLIENT_REFCOUNT.get(key, 0) - 1)
        cls._CLIENT_REFCOUNT[key] = remaining
        if remaining != 0:
            return

        # The key is derived from connection credentials, so only a short
        # fingerprint of it is written to the log.
        key_id = hashlib.sha256(key.encode("utf-8")).hexdigest()[:12]
        try:
            cls._CLIENT_CACHE[key].close()
        except Exception as e:
            # Best-effort close: a failure must not keep a dead entry cached.
            logger.warning(f"[NebulaGraphDBSync] Error closing client: {e}")
        finally:
            cls._CLIENT_CACHE.pop(key, None)
            cls._CLIENT_REFCOUNT.pop(key, None)
            logger.info(f"[NebulaGraphDBSync] Closed & removed client key_id={key_id}")
304179

305180
@classmethod
306-
def close_all_shared_pools(cls):
307-
"""Force close all cached pools. Call this on graceful shutdown."""
308-
with cls._POOL_LOCK:
309-
for key, pool in list(cls._POOL_CACHE.items()):
181+
def close_all_shared_clients(cls):
    """Force-close every cached shared client and empty the cache.

    Each client's ``close()`` is attempted independently: a failure is
    logged as a warning and does not stop the remaining clients from being
    closed. Afterwards ``_CLIENT_CACHE`` and ``_CLIENT_REFCOUNT`` are
    cleared regardless of outstanding refcounts. Runs while holding
    ``_CLIENT_LOCK``.
    """
    import hashlib

    with cls._CLIENT_LOCK:
        # Iterate over a snapshot so the cache is not walked while mutated.
        for key, client in list(cls._CLIENT_CACHE.items()):
            # The key is derived from connection credentials, so only a
            # short fingerprint of it is written to the log.
            key_id = hashlib.sha256(key.encode("utf-8")).hexdigest()[:12]
            try:
                client.close()
            except Exception as e:
                logger.warning(f"[NebulaGraphDBSync] Error closing client {key_id}: {e}")
            else:
                # Report success only when close() actually succeeded.
                logger.info(f"[NebulaGraphDBSync] Closed client key_id={key_id}")
        cls._CLIENT_CACHE.clear()
        cls._CLIENT_REFCOUNT.clear()
318192

319193
@require_python_package(
320194
import_name="nebulagraph_python",
321-
install_command="pip install ... @Tianxing",
195+
install_command="pip install nebulagraph-python>=5.1.1",
322196
install_link=".....",
323197
)
324198
def __init__(self, config: NebulaGraphDBConfig):
@@ -376,34 +250,35 @@ def __init__(self, config: NebulaGraphDBConfig):
376250

377251
# ---- NEW: pool acquisition strategy
378252
# Get or create a shared pool from the class-level cache
379-
self._pool_key, self.pool = self._get_or_create_shared_pool(config)
380-
self._owns_pool = True # We manage refcount for this instance
253+
self._client_key, self._client = self._get_or_create_shared_client(config)
254+
self._owns_client = True
381255

382256
# auto-create graph type / graph / index if needed
383-
if config.auto_create:
257+
if getattr(config, "auto_create", False):
384258
self._ensure_database_exists()
385259

386260
self.execute_query(f"SESSION SET GRAPH `{self.db_name}`")
387261

388262
# Create only if not exists
389263
self.create_index(dimensions=config.embedding_dimension)
390-
391264
logger.info("Connected to NebulaGraph successfully.")
392265

393266
@timed
394267
def execute_query(self, gql: str, timeout: float = 10.0, auto_set_db: bool = True):
    """Execute ``gql`` on the shared client, retrying once on a lost session.

    When ``auto_set_db`` is true and ``self.db_name`` is set, a
    ``SESSION SET GRAPH`` statement for that graph is issued before the
    query. If the attempt fails with an error whose text contains
    ``"Session not found"`` or ``"Connection not established"``, the whole
    sequence (graph selection included) is run one more time; any other
    error is re-raised immediately.

    Args:
        gql: Statement to execute.
        timeout: Passed to the client's ``execute`` call for ``gql``.
        auto_set_db: Whether to select ``self.db_name`` before the query.

    Returns:
        Whatever the client's ``execute`` returns for ``gql``.

    Raises:
        Exception: The client's error, when it is not one of the two
            retryable kinds or when the single retry also fails.
    """

    def _run_once():
        # Graph selection and query travel together so a retry does not run
        # the query without the graph having been selected first.
        if auto_set_db and self.db_name:
            self._client.execute(f"SESSION SET GRAPH `{self.db_name}`")
        return self._client.execute(gql, timeout=timeout)

    try:
        return _run_once()
    except Exception as e:
        emsg = str(e)
        if "Session not found" not in emsg and "Connection not established" not in emsg:
            raise
        logger.warning(f"[execute_query] {e!s}, retry once...")

    try:
        return _run_once()
    except Exception:
        logger.exception("[execute_query] retry failed")
        raise
407282

408283
@timed
409284
def close(self):
@@ -414,13 +289,13 @@ def close(self):
414289
- If pool was acquired via shared cache, decrement refcount and close
415290
when the last owner releases it.
416291
"""
417-
if not self._owns_pool:
418-
logger.debug("[NebulaGraphDB] close() skipped (injected pool).")
292+
if not self._owns_client:
293+
logger.debug("[NebulaGraphDBSync] close() skipped (injected client).")
419294
return
420-
if self._pool_key:
421-
self._release_shared_pool(self._pool_key)
422-
self._pool_key = None
423-
self.pool = None
295+
if self._client_key:
296+
self._release_shared_client(self._client_key)
297+
self._client_key = None
298+
self._client = None
424299

425300
# NOTE: __del__ is best-effort; do not rely on GC order.
426301
def __del__(self):

0 commit comments

Comments
 (0)