3
3
4
4
from neo4j import AsyncDriver , AsyncGraphDatabase , AsyncSession , Record , RoutingControl
5
5
from neo4j .auth_management import AsyncAuthManagers
6
+ from neo4j .exceptions import (
7
+ AuthError ,
8
+ TransientError ,
9
+ ServiceUnavailable ,
10
+ SessionExpired ,
11
+ )
6
12
from nodestream .file_io import LazyLoadedArgument
7
13
8
14
from .query import Query
9
15
10
16
17
+ RETRYABLE_EXCEPTIONS = (TransientError , ServiceUnavailable , SessionExpired , AuthError )
18
+
19
+
11
20
def auth_provider_factory (
12
21
username : Union [str , LazyLoadedArgument ],
13
22
password : Union [str , LazyLoadedArgument ],
@@ -42,20 +51,37 @@ def from_configuration(
42
51
username : Union [str , LazyLoadedArgument ],
43
52
password : Union [str , LazyLoadedArgument ],
44
53
database_name : str = "neo4j" ,
45
- ** driver_kwargs
54
+ max_retry_attempts : int = 3 ,
55
+ ** driver_kwargs ,
46
56
):
47
- auth = AsyncAuthManagers .basic (auth_provider_factory (username , password ))
48
- driver = AsyncGraphDatabase .driver (uri , auth = auth , ** driver_kwargs )
49
- return cls (driver , database_name )
57
+ def driver_factory ():
58
+ auth = AsyncAuthManagers .basic (auth_provider_factory (username , password ))
59
+ return AsyncGraphDatabase .driver (uri , auth = auth , ** driver_kwargs )
60
+
61
+ return cls (driver_factory , database_name , max_retry_attempts )
50
62
51
- def __init__ (self , driver : AsyncDriver , database_name : str ) -> None :
52
- self .driver = driver
63
+ def __init__ (
64
+ self , driver_factory , database_name : str , max_retry_attempts : int
65
+ ) -> None :
66
+ self .driver_factory = driver_factory
53
67
self .database_name = database_name
54
68
self .logger = getLogger (self .__class__ .__name__ )
69
+ self .max_retry_attempts = max_retry_attempts
70
+ self ._driver = None
55
71
56
- async def execute (
57
- self , query : Query , log_result : bool = False , routing_ = RoutingControl .WRITE
58
- ) -> Iterable [Record ]:
72
+ def acquire_driver (self ) -> AsyncDriver :
73
+ self ._driver = self .driver_factory ()
74
+
75
+ @property
76
+ def driver (self ):
77
+ if self ._driver is None :
78
+ self .acquire_driver ()
79
+ return self ._driver
80
+
81
+ def session (self ) -> AsyncSession :
82
+ return self .driver .session (database = self .database_name )
83
+
84
+ def log_query_start (self , query : Query ):
59
85
self .logger .info (
60
86
"Executing Cypher Query to Neo4j" ,
61
87
extra = {
@@ -64,23 +90,43 @@ async def execute(
64
90
},
65
91
)
66
92
93
+ def log_record (self , record : Record ):
94
+ self .logger .info (
95
+ "Gathered Query Results" ,
96
+ extra = dict (** record , uri = self .driver ._pool .address .host ),
97
+ )
98
+
99
+ async def _execute_query (
100
+ self , query : Query , log_result : bool = False , routing_ = RoutingControl .WRITE
101
+ ) -> Record :
67
102
result = await self .driver .execute_query (
68
103
query .query_statement ,
69
104
query .parameters ,
70
105
database_ = self .database_name ,
71
106
routing_ = routing_ ,
72
107
)
108
+ records = result .records
73
109
if log_result :
74
- for record in result .records :
75
- self .logger .info (
76
- "Gathered Query Results" ,
77
- extra = dict (
78
- ** record ,
79
- query = query .query_statement ,
80
- uri = self .driver ._pool .address .host
81
- ),
82
- )
83
- return result .records
110
+ for record in records :
111
+ self .log_record (record )
84
112
85
- def session (self ) -> AsyncSession :
86
- return self .driver .session (database = self .database_name )
113
+ return records
114
+
115
+ async def execute (
116
+ self , query : Query , log_result : bool = False , routing_ = RoutingControl .WRITE
117
+ ) -> Iterable [Record ]:
118
+ self .log_query_start (query )
119
+ attempts = 0
120
+ while True :
121
+ attempts += 1
122
+ try :
123
+ return await self ._execute_query (query , log_result , routing_ )
124
+ except RETRYABLE_EXCEPTIONS as e :
125
+ self .logger .warning (
126
+ f"Error executing query, retrying. Attempt { attempts + 1 } " ,
127
+ exc_info = e ,
128
+ )
129
+ self .acquire_driver ()
130
+ if attempts >= self .max_retry_attempts :
131
+ message = f"Failed to execute after { self .max_retry_attempts } tries"
132
+ raise Exception (message ) from e
0 commit comments