1
+ import asyncio
1
2
from logging import getLogger
2
3
from typing import Awaitable , Iterable , Tuple , Union
3
4
4
5
from neo4j import AsyncDriver , AsyncGraphDatabase , AsyncSession , Record , RoutingControl
5
6
from neo4j .auth_management import AsyncAuthManagers
7
+ from neo4j .exceptions import (
8
+ AuthError ,
9
+ ServiceUnavailable ,
10
+ SessionExpired ,
11
+ TransientError ,
12
+ )
6
13
from nodestream .file_io import LazyLoadedArgument
7
14
8
15
from .query import Query
9
16
17
+ RETRYABLE_EXCEPTIONS = (TransientError , ServiceUnavailable , SessionExpired , AuthError )
18
+
10
19
11
20
def auth_provider_factory (
12
21
username : Union [str , LazyLoadedArgument ],
@@ -42,20 +51,40 @@ def from_configuration(
42
51
username : Union [str , LazyLoadedArgument ],
43
52
password : Union [str , LazyLoadedArgument ],
44
53
database_name : str = "neo4j" ,
45
- ** driver_kwargs
54
+ max_retry_attempts : int = 3 ,
55
+ retry_factor : int = 1 ,
56
+ ** driver_kwargs ,
46
57
):
47
- auth = AsyncAuthManagers .basic (auth_provider_factory (username , password ))
48
- driver = AsyncGraphDatabase .driver (uri , auth = auth , ** driver_kwargs )
49
- return cls (driver , database_name )
58
+ def driver_factory ():
59
+ auth = AsyncAuthManagers .basic (auth_provider_factory (username , password ))
60
+ return AsyncGraphDatabase .driver (uri , auth = auth , ** driver_kwargs )
61
+
62
+ return cls (driver_factory , database_name , max_retry_attempts , retry_factor )
50
63
51
- def __init__ (self , driver : AsyncDriver , database_name : str ) -> None :
52
- self .driver = driver
64
+ def __init__ (
65
+ self ,
66
+ driver_factory ,
67
+ database_name : str ,
68
+ max_retry_attempts : int = 3 ,
69
+ retry_factor : float = 1 ,
70
+ ) -> None :
71
+ self .driver_factory = driver_factory
53
72
self .database_name = database_name
54
73
self .logger = getLogger (self .__class__ .__name__ )
74
+ self .max_retry_attempts = max_retry_attempts
75
+ self .retry_factor = retry_factor
76
+ self ._driver = None
55
77
56
- async def execute (
57
- self , query : Query , log_result : bool = False , routing_ = RoutingControl .WRITE
58
- ) -> Iterable [Record ]:
78
+ def acquire_driver (self ) -> AsyncDriver :
79
+ self ._driver = self .driver_factory ()
80
+
81
+ @property
82
+ def driver (self ):
83
+ if self ._driver is None :
84
+ self .acquire_driver ()
85
+ return self ._driver
86
+
87
+ def log_query_start (self , query : Query ):
59
88
self .logger .info (
60
89
"Executing Cypher Query to Neo4j" ,
61
90
extra = {
@@ -64,23 +93,52 @@ async def execute(
64
93
},
65
94
)
66
95
96
+ def log_record (self , record : Record ):
97
+ self .logger .info (
98
+ "Gathered Query Results" ,
99
+ extra = dict (** record , uri = self .driver ._pool .address .host ),
100
+ )
101
+
102
+ async def _execute_query (
103
+ self ,
104
+ query : Query ,
105
+ log_result : bool = False ,
106
+ routing_ = RoutingControl .WRITE ,
107
+ ) -> Record :
67
108
result = await self .driver .execute_query (
68
109
query .query_statement ,
69
110
query .parameters ,
70
111
database_ = self .database_name ,
71
112
routing_ = routing_ ,
72
113
)
114
+ records = result .records
73
115
if log_result :
74
- for record in result .records :
75
- self .logger .info (
76
- "Gathered Query Results" ,
77
- extra = dict (
78
- ** record ,
79
- query = query .query_statement ,
80
- uri = self .driver ._pool .address .host
81
- ),
116
+ for record in records :
117
+ self .log_record (record )
118
+
119
+ return records
120
+
121
+ async def execute (
122
+ self ,
123
+ query : Query ,
124
+ log_result : bool = False ,
125
+ routing_ = RoutingControl .WRITE ,
126
+ ) -> Iterable [Record ]:
127
+ self .log_query_start (query )
128
+ attempts = 0
129
+ while True :
130
+ attempts += 1
131
+ try :
132
+ return await self ._execute_query (query , log_result , routing_ )
133
+ except RETRYABLE_EXCEPTIONS as e :
134
+ self .logger .warning (
135
+ f"Error executing query, retrying. Attempt { attempts + 1 } " ,
136
+ exc_info = e ,
82
137
)
83
- return result .records
138
+ await asyncio .sleep (self .retry_factor * attempts )
139
+ self .acquire_driver ()
140
+ if attempts >= self .max_retry_attempts :
141
+ raise e
84
142
85
143
def session (self ) -> AsyncSession :
86
144
return self .driver .session (database = self .database_name )
0 commit comments