Skip to content

Commit 991c887

Browse files
committed
add: pool graph db
1 parent b946377 commit 991c887

File tree

6 files changed

+499
-0
lines changed

6 files changed

+499
-0
lines changed
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
"""
2+
Comparison: Regular vs Pooled Neo4j connections.
3+
4+
This script demonstrates the difference in connection management
5+
between regular Neo4j backend and the pooled version.
6+
"""
7+
8+
from memos.configs.graph_db import GraphDBConfigFactory
9+
from memos.graph_dbs.connection_pool import connection_pool
10+
from memos.graph_dbs.factory import GraphStoreFactory
11+
12+
13+
def create_graph_instance(backend: str, user_id: str):
    """Build a graph store for ``backend`` scoped to one demo user.

    Args:
        backend: Backend key handed to ``GraphDBConfigFactory`` (the demos in
            this script pass ``"neo4j"`` and ``"neo4j-pooled"``).
        user_id: Suffix used to derive the tenant name ``user_<user_id>``.

    Returns:
        The object produced by ``GraphStoreFactory.from_config``.
    """
    connection_settings = {
        "uri": "bolt://localhost:7687",
        "user": "neo4j",
        "password": "12345678",
        "db_name": "test_comparison",
        "user_name": f"user_{user_id}",
        "use_multi_db": False,
        "auto_create": False,  # Skip auto-creation for demo
        "embedding_dimension": 768,
    }
    factory_config = GraphDBConfigFactory(backend=backend, config=connection_settings)
    return GraphStoreFactory.from_config(factory_config)
29+
30+
31+
def demo_regular_connections():
    """Create three graph stores with the regular ``"neo4j"`` backend.

    Prints progress for each instance and a final count. The instances are
    only collected so they can be counted; nothing is returned.
    """
    print("\n=== Regular Neo4j Backend ===")
    instances = []

    for i in range(3):
        print(f"Creating instance {i + 1}...")
        # create_graph_instance() already prepends "user_" to the id, so pass
        # the bare index; passing f"user_{i}" produced "user_user_0" etc.
        instance = create_graph_instance("neo4j", str(i))
        instances.append(instance)
        print(f"Instance {i + 1} created with separate connection")

    print(f"Total instances created: {len(instances)}")
    print("Note: Each instance has its own database connection")
44+
45+
46+
def demo_pooled_connections():
    """Create three graph stores with the ``"neo4j-pooled"`` backend.

    Prints the number of drivers held by the shared ``connection_pool`` before
    and after each instance is created, then releases every pooled driver so
    the script does not exit with open connections.
    """
    print("\n=== Neo4j Pooled Backend ===")
    print(f"Initial pool connections: {connection_pool.get_active_connections()}")

    instances = []

    try:
        for i in range(3):
            print(f"Creating instance {i + 1}...")
            # create_graph_instance() already prepends "user_" to the id, so
            # pass the bare index; f"user_{i}" produced "user_user_0" etc.
            instance = create_graph_instance("neo4j-pooled", str(i))
            instances.append(instance)
            print(f"Pool connections: {connection_pool.get_active_connections()}")

        print(f"Total instances created: {len(instances)}")
        print(f"Shared connections in pool: {connection_pool.get_active_connections()}")
        print("Note: All instances share the same database connection!")
    finally:
        # The pool owns the drivers, so closing them is the pool's job; do it
        # even when instance creation fails part-way through.
        connection_pool.close_all()
62+
63+
64+
def main():
    """Run the regular and pooled demos in turn, then print a summary."""
    print("=== Neo4j Connection Management Comparison ===")

    # Regular backend first, pooled backend second.
    for run_demo in (demo_regular_connections, demo_pooled_connections):
        run_demo()

    summary_lines = (
        "\n=== Summary ===",
        "• Regular backend: Each instance = 1 connection",
        "• Pooled backend: Multiple instances = 1 shared connection",
        "• Pooled version reduces connection overhead for multi-user scenarios",
    )
    for line in summary_lines:
        print(line)


if __name__ == "__main__":
    main()
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
"""
2+
Example: Using Neo4j connection pooling to reduce connection overhead.
3+
4+
This example demonstrates how to use the neo4j-pooled backend to share
5+
database connections across multiple users/memory instances.
6+
"""
7+
8+
from memos.configs.mem_cube import GeneralMemCubeConfig
9+
from memos.graph_dbs.connection_pool import connection_pool
10+
from memos.mem_cube.general import GeneralMemCube
11+
12+
13+
def create_user_cube(user_id: str, openai_api_key: str) -> GeneralMemCube:
    """Create a memory cube for a user using pooled connections.

    Args:
        user_id: Identifier used for both the cube id and the graph tenant
            name (``user_<user_id>``).
        openai_api_key: API key placed in the extractor and dispatcher LLM
            configs.

    Returns:
        A ``GeneralMemCube`` built from the assembled configuration.
    """

    def openai_llm() -> dict:
        # Fresh dict per call so the two LLM sections never share one object.
        return {
            "backend": "openai",
            "config": {
                "api_key": openai_api_key,
                "model_name": "gpt-4o-mini",
            },
        }

    # The graph index dimension has to match the embedder configured below:
    # sentence-transformers/all-mpnet-base-v2 emits 768-dimensional vectors
    # (the previous value, 3072, did not correspond to this model).
    embedding_dimension = 768

    config = GeneralMemCubeConfig(
        cube_id=f"user_{user_id}",
        text_mem={
            "backend": "tree_text",
            "config": {
                "extractor_llm": openai_llm(),
                "dispatcher_llm": openai_llm(),
                "graph_db": {
                    "backend": "neo4j-pooled",  # Use pooled version
                    "config": {
                        "uri": "bolt://localhost:7687",
                        "user": "neo4j",
                        "password": "12345678",
                        "db_name": "shared_memos",
                        "user_name": f"user_{user_id}",
                        "use_multi_db": False,
                        "auto_create": True,
                        "embedding_dimension": embedding_dimension,
                    },
                },
                "embedder": {
                    "backend": "sentence_transformer",
                    "config": {"model_name_or_path": "sentence-transformers/all-mpnet-base-v2"},
                },
                "reorganize": False,
            },
        },
    )

    return GeneralMemCube(config)
58+
59+
60+
def main():
    """Demonstrate connection pooling with multiple users.

    Builds one cube per user, adds a memory to each, runs a search, and
    prints the pool's driver count along the way. Every pooled driver is
    closed before returning, including when a step fails.
    """
    import os

    # Prefer the OPENAI_API_KEY environment variable; the placeholder is kept
    # as a fallback so the script still starts without one.
    openai_api_key = os.environ.get("OPENAI_API_KEY", "your-openai-api-key-here")

    print("=== Neo4j Connection Pooling Demo ===")
    print(f"Initial connections: {connection_pool.get_active_connections()}")

    # Create multiple user cubes
    users = ["alice", "bob", "charlie"]
    cubes = {}

    try:
        for user_id in users:
            print(f"\nCreating cube for user: {user_id}")
            cubes[user_id] = create_user_cube(user_id, openai_api_key)
            print(f"Active connections: {connection_pool.get_active_connections()}")

        # Add some memories for each user
        memories = {
            "alice": "Alice loves hiking in the mountains.",
            "bob": "Bob is a software engineer who enjoys cooking.",
            "charlie": "Charlie plays guitar and loves jazz music.",
        }

        print("\n=== Adding memories ===")
        for user_id, memory in memories.items():
            text_mem = cubes[user_id].text_mem
            if text_mem:
                # NOTE(review): confirm that text_mem.add() accepts a bare
                # string; its signature is not visible from this script.
                text_mem.add(memory)
                print(f"Added memory for {user_id}")

        # Search memories
        print("\n=== Searching memories ===")
        for user_id in users:
            text_mem = cubes[user_id].text_mem
            if text_mem:
                results = text_mem.search("hobbies", top_k=1)
                if results:
                    print(f"{user_id}'s memory: {results[0].memory}")

        print(f"\nFinal active connections: {connection_pool.get_active_connections()}")
        print("Note: All users share the same database connection!")
    finally:
        # The pool owns the shared drivers; release them on every exit path.
        connection_pool.close_all()


if __name__ == "__main__":
    main()
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
"""Connection pool manager for graph databases."""
2+
3+
import threading
4+
5+
from typing import Any
6+
7+
from memos.log import get_logger
8+
9+
10+
logger = get_logger(__name__)
11+
12+
13+
class Neo4jConnectionPool:
    """Singleton cache of Neo4j drivers, keyed by ``"<uri>:<user>"``.

    Every call to ``Neo4jConnectionPool()`` returns the same object. Drivers
    are created on first request for a key and handed back to later callers
    until they are closed through ``close_driver`` or ``close_all``.
    """

    _instance = None
    _lock = threading.Lock()

    def __new__(cls):
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    instance = super().__new__(cls)
                    # Build all state before publishing the instance so no
                    # other thread can observe (or re-initialise) a
                    # half-constructed pool.
                    instance._drivers = {}
                    instance._driver_lock = threading.Lock()
                    instance._initialized = True
                    cls._instance = instance
        return cls._instance

    def __init__(self):
        """No-op: state is created exactly once in ``__new__``.

        ``__init__`` runs on every ``Neo4jConnectionPool()`` call, so doing
        the setup here could wipe already-cached drivers.
        """

    def get_driver(self, uri: str, user: str, password: str):
        """Get or create a driver for the given connection parameters.

        Args:
            uri: Neo4j URI passed to ``GraphDatabase.driver``.
            user: User name; together with ``uri`` it forms the cache key.
            password: Only used when a new driver has to be created. It is
                not part of the cache key, so a later call with the same
                ``uri``/``user`` and a different password returns the
                already-cached driver.

        Returns:
            The cached or newly created driver.
        """
        connection_key = f"{uri}:{user}"

        # Lock-free fast path. The looked-up object is returned directly, so a
        # concurrent close_driver()/close_all() cannot cause a KeyError here.
        driver = self._drivers.get(connection_key)
        if driver is not None:
            logger.debug(f"Reusing existing Neo4j driver for {connection_key}")
            return driver

        with self._driver_lock:
            # Re-check: another thread may have created it while we waited.
            driver = self._drivers.get(connection_key)
            if driver is None:
                from neo4j import GraphDatabase

                driver = GraphDatabase.driver(uri, auth=(user, password))
                self._drivers[connection_key] = driver
                logger.info(f"Created new Neo4j driver for {connection_key}")
            else:
                logger.debug(f"Using existing Neo4j driver for {connection_key}")

        return driver

    def close_all(self):
        """Close all connections in the pool.

        Errors raised by an individual ``close()`` are logged and do not stop
        the remaining drivers from being closed; the pool is empty afterwards.
        """
        with self._driver_lock:
            while self._drivers:
                connection_key, driver = self._drivers.popitem()
                try:
                    driver.close()
                    logger.info(f"Closed Neo4j driver for {connection_key}")
                except Exception as e:
                    logger.error(f"Error closing driver for {connection_key}: {e}")

    def close_driver(self, uri: str, user: str):
        """Close a specific driver.

        Does nothing when no driver is cached for ``uri``/``user``. The entry
        is removed from the pool even if ``close()`` raises, so a driver that
        failed to close is never handed out again.
        """
        connection_key = f"{uri}:{user}"
        with self._driver_lock:
            driver = self._drivers.pop(connection_key, None)
            if driver is None:
                return
            try:
                driver.close()
                logger.info(f"Closed and removed Neo4j driver for {connection_key}")
            except Exception as e:
                logger.error(f"Error closing driver for {connection_key}: {e}")

    def get_active_connections(self) -> int:
        """Get the number of drivers currently cached in the pool."""
        return len(self._drivers)


# Global connection pool instance
connection_pool = Neo4jConnectionPool()

src/memos/graph_dbs/factory.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,15 @@
44
from memos.graph_dbs.base import BaseGraphDB
55
from memos.graph_dbs.neo4j import Neo4jGraphDB
66
from memos.graph_dbs.neo4j_community import Neo4jCommunityGraphDB
7+
from memos.graph_dbs.neo4j_pooled import Neo4jPooledGraphDB
78

89

910
class GraphStoreFactory(BaseGraphDB):
1011
"""Factory for creating graph store instances."""
1112

1213
backend_to_class: ClassVar[dict[str, Any]] = {
1314
"neo4j": Neo4jGraphDB,
15+
"neo4j-pooled": Neo4jPooledGraphDB,
1416
"neo4j-community": Neo4jCommunityGraphDB,
1517
}
1618

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
"""Neo4j GraphDB implementation with connection pooling."""
2+
3+
from memos.configs.graph_db import Neo4jGraphDBConfig
4+
from memos.dependency import require_python_package
5+
from memos.graph_dbs.connection_pool import connection_pool
6+
from memos.graph_dbs.neo4j import Neo4jGraphDB
7+
from memos.log import get_logger
8+
9+
10+
logger = get_logger(__name__)
11+
12+
13+
class Neo4jPooledGraphDB(Neo4jGraphDB):
    """Neo4j graph store that takes its driver from the shared ``connection_pool``."""

    @require_python_package(
        import_name="neo4j",
        install_command="pip install neo4j",
        install_link="https://neo4j.com/docs/python-manual/current/install/",
    )
    def __init__(self, config: Neo4jGraphDBConfig):
        """Set up the store using a pooled driver instead of a private one.

        The driver is looked up in ``connection_pool`` by ``config.uri`` and
        ``config.user``, so instances configured with the same pair receive
        the same driver object.

        Tenant modes (as described for this backend):
            - ``use_multi_db = True``: dedicated database per tenant.
              ``db_name`` is the tenant's database, ``user_name`` is optional,
              and ``system_db_name`` is set to ``"system"``.
            - ``use_multi_db = False``: all tenants share one database.
              ``db_name`` is the shared database, ``system_db_name`` is set to
              the same value, and ``user_name`` is meant to isolate each
              tenant's nodes (that filtering lives in the inherited query
              methods, which are not visible here).

        Args:
            config: Connection and tenant settings for this store.
        """
        self.config = config

        # NOTE(review): the parent initializer is not invoked; presumably it
        # would open its own driver. Confirm the parent defines no other state
        # that this class needs, and that inherited cleanup does not close the
        # shared driver.
        self.driver = connection_pool.get_driver(config.uri, config.user, config.password)
        self.db_name = config.db_name
        self.user_name = config.user_name

        if config.use_multi_db:
            self.system_db_name = "system"
        else:
            self.system_db_name = config.db_name

        if config.auto_create:
            self._ensure_database_exists()

        # Create only if not exists
        self.create_index(dimensions=config.embedding_dimension)

        logger.debug(
            f"Neo4jPooledGraphDB initialized for {config.uri}:{config.user}, "
            f"total active connections: {connection_pool.get_active_connections()}"
        )

0 commit comments

Comments
 (0)