|
1 | 1 | from logging import getLogger
|
2 |
| -from typing import Iterable |
| 2 | +from typing import Awaitable, Iterable, Tuple, Union |
3 | 3 |
|
4 | 4 | from neo4j import AsyncDriver, AsyncGraphDatabase, AsyncSession, Record, RoutingControl
|
| 5 | +from neo4j.auth_management import AsyncAuthManagers |
| 6 | +from nodestream.file_io import LazyLoadedArgument |
5 | 7 |
|
6 | 8 | from .query import Query
|
7 | 9 |
|
8 | 10 |
|
| 11 | +def auth_provider_factory( |
| 12 | + username: Union[str, LazyLoadedArgument], |
| 13 | + password: Union[str, LazyLoadedArgument], |
| 14 | +) -> Awaitable[Tuple[str, str]]: |
| 15 | + logger = getLogger(__name__) |
| 16 | + |
| 17 | + async def auth_provider(): |
| 18 | + logger.info("Fetching new neo4j credentials") |
| 19 | + |
| 20 | + if isinstance(username, LazyLoadedArgument): |
| 21 | + logger.debug("Fetching username since value is lazy loaded") |
| 22 | + current_username = username.get_value() |
| 23 | + else: |
| 24 | + current_username = username |
| 25 | + |
| 26 | + if isinstance(password, LazyLoadedArgument): |
| 27 | + logger.debug("Fetching password since value is lazy loaded") |
| 28 | + current_password = password.get_value() |
| 29 | + else: |
| 30 | + current_password = password |
| 31 | + |
| 32 | + return current_username, current_password |
| 33 | + |
| 34 | + return auth_provider |
| 35 | + |
| 36 | + |
9 | 37 | class Neo4jDatabaseConnection:
|
10 | 38 | @classmethod
|
11 | 39 | def from_configuration(
|
12 | 40 | cls,
|
13 | 41 | uri: str,
|
14 |
| - username: str, |
15 |
| - password: str, |
| 42 | + username: Union[str, LazyLoadedArgument], |
| 43 | + password: Union[str, LazyLoadedArgument], |
16 | 44 | database_name: str = "neo4j",
|
17 | 45 | **driver_kwargs
|
18 | 46 | ):
|
19 |
| - driver = AsyncGraphDatabase.driver( |
20 |
| - uri, auth=(username, password), **driver_kwargs |
21 |
| - ) |
| 47 | + auth = AsyncAuthManagers.basic(auth_provider_factory(username, password)) |
| 48 | + driver = AsyncGraphDatabase.driver(uri, auth=auth, **driver_kwargs) |
22 | 49 | return cls(driver, database_name)
|
23 | 50 |
|
24 | 51 | def __init__(self, driver: AsyncDriver, database_name: str) -> None:
|
|
0 commit comments