Skip to content

Commit f5e185e

Browse files
committed
fix: fix tomasonjo (neo4j) issues
1 parent 0bb326c commit f5e185e

File tree

6 files changed

+32
-19
lines changed

6 files changed

+32
-19
lines changed

coverage-badge.svg

Lines changed: 1 addition & 1 deletion
Loading

synalinks/src/knowledge_bases/database_adapters/database_adapter.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
import copy
44
import re
5+
from typing import Any
6+
from typing import Dict
57

68
from synalinks.src.utils.naming import to_snake_case
79

@@ -72,7 +74,7 @@ async def update(self, data_model, threshold=0.8):
7274
f"{self.__class__} should implement the `update()` method"
7375
)
7476

75-
async def query(self, query):
77+
async def query(self, query: str, params: Dict[str, Any] = None, **kwargs):
7678
raise NotImplementedError(
7779
f"{self.__class__} should implement the `query()` method"
7880
)

synalinks/src/knowledge_bases/database_adapters/neo4j_adapter.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def __init__(
3131
index_name=index_name,
3232
embedding_model=embedding_model,
3333
)
34+
self.db = os.getenv("NEO4J_DATABASE", "neo4j")
3435
self.username = os.getenv("NEO4J_USERNAME", "neo4j")
3536
self.password = os.getenv("NEO4J_PASSWORD", "neo4j")
3637

@@ -52,10 +53,14 @@ def __init__(
5253

5354
if wipe_on_start:
5455
asyncio.get_event_loop().run_until_complete(
55-
self.query("MATCH (n)-[r]->() DELETE n, r"),
56-
)
57-
asyncio.get_event_loop().run_until_complete(
58-
self.query("MATCH (m) DELETE m"),
56+
self.query(
57+
"""
58+
MATCH (n)
59+
CALL (n) {
60+
DETACH DELETE n
61+
} IN TRANSACTIONS OF 10000 ROWS
62+
"""
63+
)
5964
)
6065

6166
query = "\n".join(
@@ -101,16 +106,17 @@ def __init__(
101106
self.query("CALL db.awaitIndexes(300)"),
102107
)
103108

104-
async def query(self, query: str, params: Dict[str, Any] = None):
109+
async def query(self, query: str, params: Dict[str, Any] = None, **kwargs):
105110
driver = neo4j.GraphDatabase.driver(
106111
self.index_name, auth=(self.username, self.password)
107112
)
108113
try:
109-
if params:
110-
records, _, _ = driver.execute_query(query, **params, database_="neo4j")
111-
else:
112-
records, _, _ = driver.execute_query(query, database_="neo4j")
113-
return records
114+
with driver.session(database=self.db) as session:
115+
if params:
116+
result = session.run(query, **params, **kwargs)
117+
else:
118+
result = session.run(query, **kwargs)
119+
return list(result)
114120
finally:
115121
driver.close()
116122

synalinks/src/knowledge_bases/knowledge_base.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
# License Apache 2.0: (c) 2025 Yoan Sallami (Synalinks Team)
22

3+
from typing import Any
4+
from typing import Dict
5+
36
from synalinks.src.api_export import synalinks_export
47
from synalinks.src.backend import is_symbolic_data_model
58
from synalinks.src.backend.config import maybe_initialize_telemetry
@@ -33,8 +36,9 @@ class IsPartOf(synalinks.Relation):
3336
model="ollama/mxbai-embed-large"
3437
)
3538
36-
os.environ["NEO4J_USERNAME"] = "your-neo4j-username"
37-
os.environ["NEO4J_PASSWORD"] = "your-neo4j-password"
39+
os.environ["NEO4J_DATABASE"] = "your-neo4j-db" # (Default to "neo4j")
40+
os.environ["NEO4J_USERNAME"] = "your-neo4j-username" # (Default to "neo4j")
41+
os.environ["NEO4J_PASSWORD"] = "your-neo4j-password" # (Default to "neo4j")
3842
3943
knowledge_base = synalinks.KnowledgeBase(
4044
index_name="neo4j://localhost:7687",
@@ -107,7 +111,7 @@ async def update(
107111
maybe_initialize_telemetry()
108112
return await self.adapter.update(data_model)
109113

110-
async def query(self, query: str):
114+
async def query(self, query: str, params: Dict[str, Any] = None, **kwargs):
111115
"""Execute a query against the knowledge base.
112116
113117
Args:
@@ -118,7 +122,7 @@ async def query(self, query: str):
118122
(GenericResult): the query results
119123
"""
120124
maybe_initialize_telemetry()
121-
return await self.adapter.query(query)
125+
return await self.adapter.query(query, params=params, **kwargs)
122126

123127
async def similarity_search(
124128
self,

synalinks/src/modules/agents/react_agent.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from synalinks.src.programs.program import Program
99
from synalinks.src.utils.tool_utils import Tool
1010

11-
1211
_fn_END = "finish"
1312

1413

@@ -217,7 +216,9 @@ def __init__(
217216
for fn in self.functions:
218217
self.labels.append(Tool(fn).name())
219218

220-
assert _fn_END not in self.labels, f"'{_fn_END}' is a reserved keyword and cannot be used as function name"
219+
assert _fn_END not in self.labels, (
220+
f"'{_fn_END}' is a reserved keyword and cannot be used as function name"
221+
)
221222

222223
self.labels.append(_fn_END)
223224

synalinks/src/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from synalinks.src.api_export import synalinks_export
44

55
# Unique source of truth for the version number.
6-
__version__ = "0.3.006"
6+
__version__ = "0.3.007"
77

88

99
@synalinks_export("synalinks.version")

0 commit comments

Comments
 (0)