Skip to content

Commit 9d414b4

Browse files
committed
Dynamically Retrieve Credentials
1 parent 65baf4d commit 9d414b4

File tree

2 files changed

+60
-7
lines changed

2 files changed

+60
-7
lines changed

nodestream_plugin_neo4j/neo4j_database.py

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,52 @@
11
from logging import getLogger
2-
from typing import Iterable
2+
from typing import Iterable, Union, Tuple, Awaitable
3+
4+
from nodestream.file_io import LazyLoadedArgument
35

46
from neo4j import AsyncDriver, AsyncGraphDatabase, AsyncSession, Record, RoutingControl
7+
from neo4j.auth_management import AsyncAuthManagers
58

69
from .query import Query
710

811

12+
def auth_provider_factory(
13+
username: Union[str, LazyLoadedArgument],
14+
password: Union[str, LazyLoadedArgument],
15+
) -> Awaitable[Tuple[str, str]]:
16+
logger = getLogger(__name__)
17+
18+
async def auth_provider():
19+
logger.info("Fetching new neo4j credentials")
20+
21+
if isinstance(username, LazyLoadedArgument):
22+
logger.debug("Fetching username since value is lazy loaded")
23+
current_username = username.get_value()
24+
else:
25+
current_username = username
26+
27+
if isinstance(password, LazyLoadedArgument):
28+
logger.debug("Fetching password since value is lazy loaded")
29+
current_password = password.get_value()
30+
else:
31+
current_password = password
32+
33+
return current_username, current_password
34+
35+
return auth_provider
36+
37+
938
class Neo4jDatabaseConnection:
1039
@classmethod
1140
def from_configuration(
1241
cls,
1342
uri: str,
14-
username: str,
15-
password: str,
43+
username: Union[str, LazyLoadedArgument],
44+
password: Union[str, LazyLoadedArgument],
1645
database_name: str = "neo4j",
1746
**driver_kwargs
1847
):
19-
driver = AsyncGraphDatabase.driver(
20-
uri, auth=(username, password), **driver_kwargs
21-
)
48+
auth = AsyncAuthManagers.basic(auth_provider_factory(username, password))
49+
driver = AsyncGraphDatabase.driver(uri, auth=auth, **driver_kwargs)
2250
return cls(driver, database_name)
2351

2452
def __init__(self, driver: AsyncDriver, database_name: str) -> None:

tests/unit/test_neo4j_database.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,12 @@
22
from hamcrest import assert_that, equal_to
33
from neo4j import AsyncDriver, RoutingControl
44

5-
from nodestream_plugin_neo4j.neo4j_database import Neo4jDatabaseConnection
5+
from nodestream.file_io import LazyLoadedArgument
6+
7+
from nodestream_plugin_neo4j.neo4j_database import (
8+
Neo4jDatabaseConnection,
9+
auth_provider_factory,
10+
)
611
from nodestream_plugin_neo4j.query import Query
712

813

@@ -34,3 +39,23 @@ async def test_session(database_connection):
3439
session = database_connection.session()
3540
assert_that(session, equal_to(database_connection.driver.session.return_value))
3641
database_connection.driver.session.assert_called_once_with(database="neo4j")
42+
43+
44+
@pytest.mark.asyncio
45+
async def test_auth_provider_factory_with_dynamic_values(mocker):
46+
username = mocker.Mock(LazyLoadedArgument)
47+
password = mocker.Mock(LazyLoadedArgument)
48+
provider = auth_provider_factory(username, password)
49+
retrieved_username, retrieved_password = await provider()
50+
assert_that(retrieved_username, equal_to(username.get_value.return_value))
51+
assert_that(retrieved_password, equal_to(password.get_value.return_value))
52+
53+
54+
@pytest.mark.asyncio
55+
async def test_auth_provider_factory_with_static_values():
56+
username = "neo4j"
57+
password = "password"
58+
provider = auth_provider_factory(username, password)
59+
retrieved_username, retrieved_password = await provider()
60+
assert_that(retrieved_username, equal_to(username))
61+
assert_that(retrieved_password, equal_to(password))

0 commit comments

Comments
 (0)