Skip to content

Commit 15e8c4b

Browse files
authored
Merge pull request #16 from nodestream-proj/feature/dynamically-retrieve-credentials
Dynamically Retrieve Credentials
2 parents 65baf4d + 861ec40 commit 15e8c4b

File tree

3 files changed

+265
-213
lines changed

3 files changed

+265
-213
lines changed

nodestream_plugin_neo4j/neo4j_database.py

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,51 @@
11
from logging import getLogger
2-
from typing import Iterable
2+
from typing import Awaitable, Iterable, Tuple, Union
33

44
from neo4j import AsyncDriver, AsyncGraphDatabase, AsyncSession, Record, RoutingControl
5+
from neo4j.auth_management import AsyncAuthManagers
6+
from nodestream.file_io import LazyLoadedArgument
57

68
from .query import Query
79

810

11+
def auth_provider_factory(
12+
username: Union[str, LazyLoadedArgument],
13+
password: Union[str, LazyLoadedArgument],
14+
) -> Awaitable[Tuple[str, str]]:
15+
logger = getLogger(__name__)
16+
17+
async def auth_provider():
18+
logger.info("Fetching new neo4j credentials")
19+
20+
if isinstance(username, LazyLoadedArgument):
21+
logger.debug("Fetching username since value is lazy loaded")
22+
current_username = username.get_value()
23+
else:
24+
current_username = username
25+
26+
if isinstance(password, LazyLoadedArgument):
27+
logger.debug("Fetching password since value is lazy loaded")
28+
current_password = password.get_value()
29+
else:
30+
current_password = password
31+
32+
return current_username, current_password
33+
34+
return auth_provider
35+
36+
937
class Neo4jDatabaseConnection:
1038
@classmethod
1139
def from_configuration(
1240
cls,
1341
uri: str,
14-
username: str,
15-
password: str,
42+
username: Union[str, LazyLoadedArgument],
43+
password: Union[str, LazyLoadedArgument],
1644
database_name: str = "neo4j",
1745
**driver_kwargs
1846
):
19-
driver = AsyncGraphDatabase.driver(
20-
uri, auth=(username, password), **driver_kwargs
21-
)
47+
auth = AsyncAuthManagers.basic(auth_provider_factory(username, password))
48+
driver = AsyncGraphDatabase.driver(uri, auth=auth, **driver_kwargs)
2249
return cls(driver, database_name)
2350

2451
def __init__(self, driver: AsyncDriver, database_name: str) -> None:

0 commit comments

Comments
 (0)