Skip to content

Commit 149b8e5

Browse files
authored
Merge pull request #22 from nodestream-proj/handle-spurious-errors
Handle spurious errors
2 parents 7d70907 + 8b762da commit 149b8e5

File tree

2 files changed

+116
-29
lines changed

2 files changed

+116
-29
lines changed

nodestream_plugin_neo4j/neo4j_database.py

Lines changed: 76 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,21 @@
1+
import asyncio
12
from logging import getLogger
23
from typing import Awaitable, Iterable, Tuple, Union
34

45
from neo4j import AsyncDriver, AsyncGraphDatabase, AsyncSession, Record, RoutingControl
56
from neo4j.auth_management import AsyncAuthManagers
7+
from neo4j.exceptions import (
8+
AuthError,
9+
ServiceUnavailable,
10+
SessionExpired,
11+
TransientError,
12+
)
613
from nodestream.file_io import LazyLoadedArgument
714

815
from .query import Query
916

17+
RETRYABLE_EXCEPTIONS = (TransientError, ServiceUnavailable, SessionExpired, AuthError)
18+
1019

1120
def auth_provider_factory(
1221
username: Union[str, LazyLoadedArgument],
@@ -42,20 +51,40 @@ def from_configuration(
4251
username: Union[str, LazyLoadedArgument],
4352
password: Union[str, LazyLoadedArgument],
4453
database_name: str = "neo4j",
45-
**driver_kwargs
54+
max_retry_attempts: int = 3,
55+
retry_factor: int = 1,
56+
**driver_kwargs,
4657
):
47-
auth = AsyncAuthManagers.basic(auth_provider_factory(username, password))
48-
driver = AsyncGraphDatabase.driver(uri, auth=auth, **driver_kwargs)
49-
return cls(driver, database_name)
58+
def driver_factory():
59+
auth = AsyncAuthManagers.basic(auth_provider_factory(username, password))
60+
return AsyncGraphDatabase.driver(uri, auth=auth, **driver_kwargs)
61+
62+
return cls(driver_factory, database_name, max_retry_attempts, retry_factor)
5063

51-
def __init__(self, driver: AsyncDriver, database_name: str) -> None:
52-
self.driver = driver
64+
def __init__(
65+
self,
66+
driver_factory,
67+
database_name: str,
68+
max_retry_attempts: int = 3,
69+
retry_factor: float = 1,
70+
) -> None:
71+
self.driver_factory = driver_factory
5372
self.database_name = database_name
5473
self.logger = getLogger(self.__class__.__name__)
74+
self.max_retry_attempts = max_retry_attempts
75+
self.retry_factor = retry_factor
76+
self._driver = None
5577

56-
async def execute(
57-
self, query: Query, log_result: bool = False, routing_=RoutingControl.WRITE
58-
) -> Iterable[Record]:
78+
def acquire_driver(self) -> AsyncDriver:
79+
self._driver = self.driver_factory()
80+
81+
@property
82+
def driver(self):
83+
if self._driver is None:
84+
self.acquire_driver()
85+
return self._driver
86+
87+
def log_query_start(self, query: Query):
5988
self.logger.info(
6089
"Executing Cypher Query to Neo4j",
6190
extra={
@@ -64,23 +93,52 @@ async def execute(
6493
},
6594
)
6695

96+
def log_record(self, record: Record):
97+
self.logger.info(
98+
"Gathered Query Results",
99+
extra=dict(**record, uri=self.driver._pool.address.host),
100+
)
101+
102+
async def _execute_query(
103+
self,
104+
query: Query,
105+
log_result: bool = False,
106+
routing_=RoutingControl.WRITE,
107+
) -> Record:
67108
result = await self.driver.execute_query(
68109
query.query_statement,
69110
query.parameters,
70111
database_=self.database_name,
71112
routing_=routing_,
72113
)
114+
records = result.records
73115
if log_result:
74-
for record in result.records:
75-
self.logger.info(
76-
"Gathered Query Results",
77-
extra=dict(
78-
**record,
79-
query=query.query_statement,
80-
uri=self.driver._pool.address.host
81-
),
116+
for record in records:
117+
self.log_record(record)
118+
119+
return records
120+
121+
async def execute(
122+
self,
123+
query: Query,
124+
log_result: bool = False,
125+
routing_=RoutingControl.WRITE,
126+
) -> Iterable[Record]:
127+
self.log_query_start(query)
128+
attempts = 0
129+
while True:
130+
attempts += 1
131+
try:
132+
return await self._execute_query(query, log_result, routing_)
133+
except RETRYABLE_EXCEPTIONS as e:
134+
self.logger.warning(
135+
f"Error executing query, retrying. Attempt {attempts + 1}",
136+
exc_info=e,
82137
)
83-
return result.records
138+
await asyncio.sleep(self.retry_factor * attempts)
139+
self.acquire_driver()
140+
if attempts >= self.max_retry_attempts:
141+
raise e
84142

85143
def session(self) -> AsyncSession:
86144
return self.driver.session(database=self.database_name)

tests/unit/test_neo4j_database.py

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import pytest
22
from hamcrest import assert_that, equal_to
33
from neo4j import AsyncDriver, RoutingControl
4+
from neo4j.exceptions import TransientError
45
from nodestream.file_io import LazyLoadedArgument
56

67
from nodestream_plugin_neo4j.neo4j_database import (
@@ -9,30 +10,58 @@
910
)
1011
from nodestream_plugin_neo4j.query import Query
1112

13+
A_QUERY = Query("MATCH (n) RETURN n LIMIT $limit", {"limit": 2})
14+
SOME_RECORDS = [
15+
{"n": {"name": "foo"}},
16+
{"n": {"name": "bar"}},
17+
]
18+
1219

1320
@pytest.fixture
1421
def database_connection(mocker):
15-
return Neo4jDatabaseConnection(mocker.AsyncMock(AsyncDriver), "neo4j")
22+
return Neo4jDatabaseConnection(
23+
lambda: mocker.AsyncMock(AsyncDriver), "neo4j", 2, 0.1
24+
)
1625

1726

1827
@pytest.mark.asyncio
1928
async def test_execute(database_connection):
20-
query = Query("MATCH (n) RETURN n LIMIT $limit", {"limit": 2})
21-
records = [
22-
{"n": {"name": "foo"}},
23-
{"n": {"name": "bar"}},
24-
]
25-
database_connection.driver.execute_query.return_value.records = records
26-
result = await database_connection.execute(query, log_result=True)
27-
assert_that(result, equal_to(records))
29+
database_connection.driver.execute_query.return_value.records = SOME_RECORDS
30+
result = await database_connection.execute(A_QUERY, log_result=True)
31+
assert_that(result, equal_to(SOME_RECORDS))
2832
database_connection.driver.execute_query.assert_called_once_with(
29-
query.query_statement,
30-
query.parameters,
33+
A_QUERY.query_statement,
34+
A_QUERY.parameters,
3135
database_="neo4j",
3236
routing_=RoutingControl.WRITE,
3337
)
3438

3539

40+
@pytest.mark.asyncio
41+
async def test_execute_fail_and_then_succeed(database_connection, mocker):
42+
database_connection.acquire_driver = mocker.Mock(
43+
wraps=database_connection.acquire_driver
44+
)
45+
database_connection.driver.execute_query.side_effect = [
46+
TransientError("Failed to execute query"),
47+
SOME_RECORDS,
48+
]
49+
await database_connection.execute(A_QUERY)
50+
assert_that(database_connection.acquire_driver.call_count, equal_to(2))
51+
52+
53+
@pytest.mark.asyncio
54+
async def test_execute_fail_and_then_fail(database_connection, mocker):
55+
def driver_factory():
56+
driver = mocker.AsyncMock(AsyncDriver)
57+
driver.execute_query.side_effect = TransientError("Failed to execute query")
58+
return driver
59+
60+
database_connection.driver_factory = driver_factory
61+
with pytest.raises(TransientError):
62+
await database_connection.execute(A_QUERY)
63+
64+
3665
@pytest.mark.asyncio
3766
async def test_session(database_connection):
3867
session = database_connection.session()

0 commit comments

Comments
 (0)