Skip to content

Commit 35799a0

Browse files
committed
handle spurious errors with retry
1 parent 7d70907 commit 35799a0

File tree

2 files changed

+68
-22
lines changed

2 files changed

+68
-22
lines changed

nodestream_plugin_neo4j/neo4j_database.py

Lines changed: 67 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,20 @@
33

44
from neo4j import AsyncDriver, AsyncGraphDatabase, AsyncSession, Record, RoutingControl
55
from neo4j.auth_management import AsyncAuthManagers
6+
from neo4j.exceptions import (
7+
AuthError,
8+
TransientError,
9+
ServiceUnavailable,
10+
SessionExpired,
11+
)
612
from nodestream.file_io import LazyLoadedArgument
713

814
from .query import Query
915

1016

17+
RETRYABLE_EXCEPTIONS = (TransientError, ServiceUnavailable, SessionExpired, AuthError)
18+
19+
1120
def auth_provider_factory(
1221
username: Union[str, LazyLoadedArgument],
1322
password: Union[str, LazyLoadedArgument],
@@ -42,20 +51,37 @@ def from_configuration(
4251
username: Union[str, LazyLoadedArgument],
4352
password: Union[str, LazyLoadedArgument],
4453
database_name: str = "neo4j",
45-
**driver_kwargs
54+
max_retry_attempts: int = 3,
55+
**driver_kwargs,
4656
):
47-
auth = AsyncAuthManagers.basic(auth_provider_factory(username, password))
48-
driver = AsyncGraphDatabase.driver(uri, auth=auth, **driver_kwargs)
49-
return cls(driver, database_name)
57+
def driver_factory():
58+
auth = AsyncAuthManagers.basic(auth_provider_factory(username, password))
59+
return AsyncGraphDatabase.driver(uri, auth=auth, **driver_kwargs)
60+
61+
return cls(driver_factory, database_name, max_retry_attempts)
5062

51-
def __init__(self, driver: AsyncDriver, database_name: str) -> None:
52-
self.driver = driver
63+
def __init__(
64+
self, driver_factory, database_name: str, max_retry_attempts: int
65+
) -> None:
66+
self.driver_factory = driver_factory
5367
self.database_name = database_name
5468
self.logger = getLogger(self.__class__.__name__)
69+
self.max_retry_attempts = max_retry_attempts
70+
self._driver = None
5571

56-
async def execute(
57-
self, query: Query, log_result: bool = False, routing_=RoutingControl.WRITE
58-
) -> Iterable[Record]:
72+
def acquire_driver(self) -> AsyncDriver:
73+
self._driver = self.driver_factory()
74+
75+
@property
76+
def driver(self):
77+
if self._driver is None:
78+
self.acquire_driver()
79+
return self._driver
80+
81+
def session(self) -> AsyncSession:
82+
return self.driver.session(database=self.database_name)
83+
84+
def log_query_start(self, query: Query):
5985
self.logger.info(
6086
"Executing Cypher Query to Neo4j",
6187
extra={
@@ -64,23 +90,43 @@ async def execute(
6490
},
6591
)
6692

93+
def log_record(self, record: Record):
94+
self.logger.info(
95+
"Gathered Query Results",
96+
extra=dict(**record, uri=self.driver._pool.address.host),
97+
)
98+
99+
async def _execute_query(
100+
self, query: Query, log_result: bool = False, routing_=RoutingControl.WRITE
101+
) -> Record:
67102
result = await self.driver.execute_query(
68103
query.query_statement,
69104
query.parameters,
70105
database_=self.database_name,
71106
routing_=routing_,
72107
)
108+
records = result.records
73109
if log_result:
74-
for record in result.records:
75-
self.logger.info(
76-
"Gathered Query Results",
77-
extra=dict(
78-
**record,
79-
query=query.query_statement,
80-
uri=self.driver._pool.address.host
81-
),
82-
)
83-
return result.records
110+
for record in records:
111+
self.log_record(record)
84112

85-
def session(self) -> AsyncSession:
86-
return self.driver.session(database=self.database_name)
113+
return records
114+
115+
async def execute(
116+
self, query: Query, log_result: bool = False, routing_=RoutingControl.WRITE
117+
) -> Iterable[Record]:
118+
self.log_query_start(query)
119+
attempts = 0
120+
while True:
121+
attempts += 1
122+
try:
123+
return await self._execute_query(query, log_result, routing_)
124+
except RETRYABLE_EXCEPTIONS as e:
125+
self.logger.warning(
126+
f"Error executing query, retrying. Attempt {attempts + 1}",
127+
exc_info=e,
128+
)
129+
self.acquire_driver()
130+
if attempts >= self.max_retry_attempts:
131+
message = f"Failed to execute after {self.max_retry_attempts} tries"
132+
raise Exception(message) from e

tests/unit/test_neo4j_database.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
@pytest.fixture
1414
def database_connection(mocker):
15-
return Neo4jDatabaseConnection(mocker.AsyncMock(AsyncDriver), "neo4j")
15+
return Neo4jDatabaseConnection(mocker.AsyncMock(AsyncDriver), "neo4j", 3)
1616

1717

1818
@pytest.mark.asyncio

0 commit comments

Comments
 (0)