11import json
22import logging
33import re
4- from typing import Any , Literal , Optional
4+ from typing import Any , Literal
55
66from fastmcp .exceptions import ToolError
77from fastmcp .tools .tool import ToolResult , TextContent
1010from neo4j import (
1111 AsyncDriver ,
1212 AsyncGraphDatabase ,
13- AsyncResult ,
14- AsyncTransaction ,
13+ RoutingControl
1514)
1615from neo4j .exceptions import ClientError , Neo4jError
1716from pydantic import Field
@@ -27,19 +26,6 @@ def _format_namespace(namespace: str) -> str:
2726 else :
2827 return ""
2928
30- async def _read (tx : AsyncTransaction , query : str , params : dict [str , Any ]) -> str :
31- raw_results = await tx .run (query , params )
32- eager_results = await raw_results .to_eager_result ()
33-
34- return json .dumps ([r .data () for r in eager_results .records ], default = str )
35-
36-
37- async def _write (
38- tx : AsyncTransaction , query : str , params : dict [str , Any ]
39- ) -> AsyncResult :
40- return await tx .run (query , params )
41-
42-
4329def _is_write_query (query : str ) -> bool :
4430 """Check if the query is a write query."""
4531 return (
@@ -135,18 +121,19 @@ def clean_schema(schema: dict) -> dict:
135121
136122
137123 try :
138- async with neo4j_driver .session (database = database ) as session :
139- results_json_str = await session .execute_read (
140- _read , get_schema_query , dict ()
141- )
142124
143- logger .debug (f"Read query returned { len (results_json_str )} rows" )
125+ results_json_str = await neo4j_driver .execute_query (get_schema_query ,
126+ routing_control = RoutingControl .READ ,
127+ database_ = database ,
128+ result_transformer_ = lambda r : r .data ())
129+
130+ logger .debug (f"Read query returned { len (results_json_str )} rows" )
131+
132+ schema_clean = clean_schema (results_json_str [0 ].get ('value' ))
144133
145- schema = json .loads (results_json_str )[0 ].get ('value' )
146- schema_clean = clean_schema (schema )
147- schema_clean_str = json .dumps (schema_clean )
134+ schema_clean_str = json .dumps (schema_clean , default = str )
148135
149- return ToolResult (content = [TextContent (type = "text" , text = schema_clean_str )])
136+ return ToolResult (content = [TextContent (type = "text" , text = schema_clean_str )])
150137
151138 except ClientError as e :
152139 if "Neo.ClientError.Procedure.ProcedureNotFound" in str (e ):
@@ -180,12 +167,17 @@ async def read_neo4j_cypher(
180167 raise ValueError ("Only MATCH queries are allowed for read-query" )
181168
182169 try :
183- async with neo4j_driver .session (database = database ) as session :
184- results_json_str = await session .execute_read (_read , query , params )
170+ results = await neo4j_driver .execute_query (query ,
171+ parameters_ = params ,
172+ routing_control = RoutingControl .READ ,
173+ database_ = database ,
174+ result_transformer_ = lambda r : r .data ())
175+
176+ results_json_str = json .dumps (results , default = str )
185177
186- logger .debug (f"Read query returned { len (results_json_str )} rows" )
178+ logger .debug (f"Read query returned { len (results_json_str )} rows" )
187179
188- return ToolResult (content = [TextContent (type = "text" , text = results_json_str )])
180+ return ToolResult (content = [TextContent (type = "text" , text = results_json_str )])
189181
190182 except Neo4jError as e :
191183 logger .error (f"Neo4j Error executing read query: { e } \n { query } \n { params } " )
@@ -214,11 +206,13 @@ async def write_neo4j_cypher(
214206 raise ValueError ("Only write queries are allowed for write-query" )
215207
216208 try :
217- async with neo4j_driver .session (database = database ) as session :
218- raw_results = await session .execute_write (_write , query , params )
219- counters_json_str = json .dumps (
220- raw_results ._summary .counters .__dict__ , default = str
221- )
209+ _ , summary , _ = await neo4j_driver .execute_query (query ,
210+ parameters_ = params ,
211+ routing_control = RoutingControl .WRITE ,
212+ database_ = database ,
213+ )
214+
215+ counters_json_str = json .dumps (summary .counters .__dict__ , default = str )
222216
223217 logger .debug (f"Write query affected { counters_json_str } " )
224218
0 commit comments