Skip to content

Commit 27e9565

Browse files
committed
do retries and add tests
1 parent 35799a0 commit 27e9565

File tree

2 files changed

+60
-19
lines changed

2 files changed

+60
-19
lines changed

nodestream_plugin_neo4j/neo4j_database.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
1+
import asyncio
12
from logging import getLogger
23
from typing import Awaitable, Iterable, Tuple, Union
34

45
from neo4j import AsyncDriver, AsyncGraphDatabase, AsyncSession, Record, RoutingControl
56
from neo4j.auth_management import AsyncAuthManagers
67
from neo4j.exceptions import (
78
AuthError,
8-
TransientError,
99
ServiceUnavailable,
1010
SessionExpired,
11+
TransientError,
1112
)
1213
from nodestream.file_io import LazyLoadedArgument
1314

1415
from .query import Query
1516

16-
1717
RETRYABLE_EXCEPTIONS = (TransientError, ServiceUnavailable, SessionExpired, AuthError)
1818

1919

@@ -52,21 +52,27 @@ def from_configuration(
5252
password: Union[str, LazyLoadedArgument],
5353
database_name: str = "neo4j",
5454
max_retry_attempts: int = 3,
55+
retry_factor: int = 1,
5556
**driver_kwargs,
5657
):
5758
def driver_factory():
5859
auth = AsyncAuthManagers.basic(auth_provider_factory(username, password))
5960
return AsyncGraphDatabase.driver(uri, auth=auth, **driver_kwargs)
6061

61-
return cls(driver_factory, database_name, max_retry_attempts)
62+
return cls(driver_factory, database_name, max_retry_attempts, retry_factor)
6263

6364
def __init__(
64-
self, driver_factory, database_name: str, max_retry_attempts: int
65+
self,
66+
driver_factory,
67+
database_name: str,
68+
max_retry_attempts: int = 3,
69+
retry_factor: float = 1,
6570
) -> None:
6671
self.driver_factory = driver_factory
6772
self.database_name = database_name
6873
self.logger = getLogger(self.__class__.__name__)
6974
self.max_retry_attempts = max_retry_attempts
75+
self.retry_factor = retry_factor
7076
self._driver = None
7177

7278
def acquire_driver(self) -> AsyncDriver:
@@ -97,7 +103,10 @@ def log_record(self, record: Record):
97103
)
98104

99105
async def _execute_query(
100-
self, query: Query, log_result: bool = False, routing_=RoutingControl.WRITE
106+
self,
107+
query: Query,
108+
log_result: bool = False,
109+
routing_=RoutingControl.WRITE,
101110
) -> Record:
102111
result = await self.driver.execute_query(
103112
query.query_statement,
@@ -113,7 +122,10 @@ async def _execute_query(
113122
return records
114123

115124
async def execute(
116-
self, query: Query, log_result: bool = False, routing_=RoutingControl.WRITE
125+
self,
126+
query: Query,
127+
log_result: bool = False,
128+
routing_=RoutingControl.WRITE,
117129
) -> Iterable[Record]:
118130
self.log_query_start(query)
119131
attempts = 0
@@ -126,7 +138,7 @@ async def execute(
126138
f"Error executing query, retrying. Attempt {attempts + 1}",
127139
exc_info=e,
128140
)
141+
await asyncio.sleep(self.retry_factor * attempts)
129142
self.acquire_driver()
130143
if attempts >= self.max_retry_attempts:
131-
message = f"Failed to execute after {self.max_retry_attempts} tries"
132-
raise Exception(message) from e
144+
raise e

tests/unit/test_neo4j_database.py

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import pytest
22
from hamcrest import assert_that, equal_to
33
from neo4j import AsyncDriver, RoutingControl
4+
from neo4j.exceptions import TransientError
45
from nodestream.file_io import LazyLoadedArgument
56

67
from nodestream_plugin_neo4j.neo4j_database import (
@@ -9,30 +10,58 @@
910
)
1011
from nodestream_plugin_neo4j.query import Query
1112

13+
A_QUERY = Query("MATCH (n) RETURN n LIMIT $limit", {"limit": 2})
14+
SOME_RECORDS = [
15+
{"n": {"name": "foo"}},
16+
{"n": {"name": "bar"}},
17+
]
18+
1219

1320
@pytest.fixture
1421
def database_connection(mocker):
15-
return Neo4jDatabaseConnection(mocker.AsyncMock(AsyncDriver), "neo4j", 3)
22+
return Neo4jDatabaseConnection(
23+
lambda: mocker.AsyncMock(AsyncDriver), "neo4j", 2, 0.1
24+
)
1625

1726

1827
@pytest.mark.asyncio
1928
async def test_execute(database_connection):
20-
query = Query("MATCH (n) RETURN n LIMIT $limit", {"limit": 2})
21-
records = [
22-
{"n": {"name": "foo"}},
23-
{"n": {"name": "bar"}},
24-
]
25-
database_connection.driver.execute_query.return_value.records = records
26-
result = await database_connection.execute(query, log_result=True)
27-
assert_that(result, equal_to(records))
29+
database_connection.driver.execute_query.return_value.records = SOME_RECORDS
30+
result = await database_connection.execute(A_QUERY, log_result=True)
31+
assert_that(result, equal_to(SOME_RECORDS))
2832
database_connection.driver.execute_query.assert_called_once_with(
29-
query.query_statement,
30-
query.parameters,
33+
A_QUERY.query_statement,
34+
A_QUERY.parameters,
3135
database_="neo4j",
3236
routing_=RoutingControl.WRITE,
3337
)
3438

3539

40+
@pytest.mark.asyncio
41+
async def test_execute_fail_and_then_succeed(database_connection, mocker):
42+
database_connection.acquire_driver = mocker.Mock(
43+
wraps=database_connection.acquire_driver
44+
)
45+
database_connection.driver.execute_query.side_effect = [
46+
TransientError("Failed to execute query"),
47+
SOME_RECORDS,
48+
]
49+
await database_connection.execute(A_QUERY)
50+
assert_that(database_connection.acquire_driver.call_count, equal_to(2))
51+
52+
53+
@pytest.mark.asyncio
54+
async def test_execute_fail_and_then_fail(database_connection, mocker):
55+
def driver_factory():
56+
driver = mocker.AsyncMock(AsyncDriver)
57+
driver.execute_query.side_effect = TransientError("Failed to execute query")
58+
return driver
59+
60+
database_connection.driver_factory = driver_factory
61+
with pytest.raises(TransientError):
62+
await database_connection.execute(A_QUERY)
63+
64+
3665
@pytest.mark.asyncio
3766
async def test_session(database_connection):
3867
session = database_connection.session()

0 commit comments

Comments
 (0)