1
+ import asyncio
1
2
from logging import getLogger
2
3
from typing import Awaitable , Iterable , Tuple , Union
3
4
4
5
from neo4j import AsyncDriver , AsyncGraphDatabase , AsyncSession , Record , RoutingControl
5
6
from neo4j .auth_management import AsyncAuthManagers
6
7
from neo4j .exceptions import (
7
8
AuthError ,
8
- TransientError ,
9
9
ServiceUnavailable ,
10
10
SessionExpired ,
11
+ TransientError ,
11
12
)
12
13
from nodestream .file_io import LazyLoadedArgument
13
14
14
15
from .query import Query
15
16
16
-
17
17
RETRYABLE_EXCEPTIONS = (TransientError , ServiceUnavailable , SessionExpired , AuthError )
18
18
19
19
@@ -52,21 +52,27 @@ def from_configuration(
52
52
password : Union [str , LazyLoadedArgument ],
53
53
database_name : str = "neo4j" ,
54
54
max_retry_attempts : int = 3 ,
55
+ retry_factor : int = 1 ,
55
56
** driver_kwargs ,
56
57
):
57
58
def driver_factory ():
58
59
auth = AsyncAuthManagers .basic (auth_provider_factory (username , password ))
59
60
return AsyncGraphDatabase .driver (uri , auth = auth , ** driver_kwargs )
60
61
61
- return cls (driver_factory , database_name , max_retry_attempts )
62
+ return cls (driver_factory , database_name , max_retry_attempts , retry_factor )
62
63
63
64
def __init__ (
64
- self , driver_factory , database_name : str , max_retry_attempts : int
65
+ self ,
66
+ driver_factory ,
67
+ database_name : str ,
68
+ max_retry_attempts : int = 3 ,
69
+ retry_factor : float = 1 ,
65
70
) -> None :
66
71
self .driver_factory = driver_factory
67
72
self .database_name = database_name
68
73
self .logger = getLogger (self .__class__ .__name__ )
69
74
self .max_retry_attempts = max_retry_attempts
75
+ self .retry_factor = retry_factor
70
76
self ._driver = None
71
77
72
78
def acquire_driver (self ) -> AsyncDriver :
@@ -97,7 +103,10 @@ def log_record(self, record: Record):
97
103
)
98
104
99
105
async def _execute_query (
100
- self , query : Query , log_result : bool = False , routing_ = RoutingControl .WRITE
106
+ self ,
107
+ query : Query ,
108
+ log_result : bool = False ,
109
+ routing_ = RoutingControl .WRITE ,
101
110
) -> Record :
102
111
result = await self .driver .execute_query (
103
112
query .query_statement ,
@@ -113,7 +122,10 @@ async def _execute_query(
113
122
return records
114
123
115
124
async def execute (
116
- self , query : Query , log_result : bool = False , routing_ = RoutingControl .WRITE
125
+ self ,
126
+ query : Query ,
127
+ log_result : bool = False ,
128
+ routing_ = RoutingControl .WRITE ,
117
129
) -> Iterable [Record ]:
118
130
self .log_query_start (query )
119
131
attempts = 0
@@ -126,7 +138,7 @@ async def execute(
126
138
f"Error executing query, retrying. Attempt { attempts + 1 } " ,
127
139
exc_info = e ,
128
140
)
141
+ await asyncio .sleep (self .retry_factor * attempts )
129
142
self .acquire_driver ()
130
143
if attempts >= self .max_retry_attempts :
131
- message = f"Failed to execute after { self .max_retry_attempts } tries"
132
- raise Exception (message ) from e
144
+ raise e
0 commit comments