Skip to content

Commit 4e62859

Browse files
ds-filipknefelFilip Knefelrbiseck3
authored
feat: Neo4j Destination Connector (#212)
Implements Neo4j destination connector. Connector takes Unstructured Elements and builds a lexical graph representing relationships between them. It consists of the following entities. Nodes - Document - represents the source file elements come - UnstructuredElement - represents the Unstructured Element prior to - Chunk - represents the Unstructured Element post chunking Edges (Relationships) - UnstructredElement/Chunk - PART_OF_DOCUMENT -> Document - relationship of belonging to the source file - UnstructuredElement - PART_OF_CHUNK -> Chunk - relationship between origin elements making up a chunk - UnstructuredElement - NEXT_ELEMENT -> UnstructuredElement - order of occurrence in the document - Chunk - NEXT_ELEMENT -> Chunk - order of occurrence in document --------- Co-authored-by: Filip Knefel <[email protected]> Co-authored-by: Roman Isecke <[email protected]>
1 parent 7c0b03f commit 4e62859

File tree

10 files changed

+647
-2
lines changed

10 files changed

+647
-2
lines changed

CHANGELOG.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
## 0.3.9-dev2
1+
## 0.3.9-dev3
22

33
### Enhancements
44

55
* **Support ndjson files in stagers**
6+
* **Add Neo4j destination connector**
67

78
### Fixes
89

requirements/connectors/neo4j.in

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
neo4j
2+
cymple

requirements/connectors/neo4j.txt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# This file was autogenerated by uv via the following command:
2+
# uv pip compile ./connectors/neo4j.in --output-file ./connectors/neo4j.txt --no-strip-extras --python-version 3.9
3+
cymple==0.11.0
4+
# via -r ./connectors/neo4j.in
5+
neo4j==5.25.0
6+
# via -r ./connectors/neo4j.in
7+
pytz==2024.2
8+
# via neo4j

requirements/test.in

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
pytest
44
pytest-cov
55
pytest-mock
6+
pytest-check
7+
unstructured
68
pytest-asyncio
79
pytest_tagging
810
pytest-json-report

requirements/test.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,13 +124,16 @@ pytest==8.3.4
124124
# via
125125
# -r test.in
126126
# pytest-asyncio
127+
# pytest-check
127128
# pytest-cov
128129
# pytest-json-report
129130
# pytest-metadata
130131
# pytest-mock
131132
# pytest-tagging
132133
pytest-asyncio==0.25.0
133134
# via -r test.in
135+
pytest-check==2.4.1
136+
# via -r test.in
134137
pytest-cov==6.0.0
135138
# via -r test.in
136139
pytest-json-report==1.5.0

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ def load_requirements(file: Union[str, Path]) -> List[str]:
108108
"lancedb": load_requirements("requirements/connectors/lancedb.in"),
109109
"milvus": load_requirements("requirements/connectors/milvus.in"),
110110
"mongodb": load_requirements("requirements/connectors/mongodb.in"),
111+
"neo4j": load_requirements("requirements/connectors/neo4j.in"),
111112
"notion": load_requirements("requirements/connectors/notion.in"),
112113
"onedrive": load_requirements("requirements/connectors/onedrive.in"),
113114
"opensearch": load_requirements("requirements/connectors/opensearch.in"),
Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
import json
2+
import time
3+
import uuid
4+
from datetime import datetime
5+
from pathlib import Path
6+
7+
import pytest
8+
from neo4j import AsyncGraphDatabase, Driver, GraphDatabase
9+
from neo4j.exceptions import ServiceUnavailable
10+
from pytest_check import check
11+
12+
from test.integration.connectors.utils.constants import DESTINATION_TAG
13+
from test.integration.connectors.utils.docker import container_context
14+
from unstructured_ingest.error import DestinationConnectionError
15+
from unstructured_ingest.utils.chunking import elements_from_base64_gzipped_json
16+
from unstructured_ingest.v2.interfaces.file_data import (
17+
FileData,
18+
FileDataSourceMetadata,
19+
SourceIdentifiers,
20+
)
21+
from unstructured_ingest.v2.processes.connectors.neo4j import (
22+
CONNECTOR_TYPE,
23+
Label,
24+
Neo4jAccessConfig,
25+
Neo4jConnectionConfig,
26+
Neo4jUploader,
27+
Neo4jUploaderConfig,
28+
Neo4jUploadStager,
29+
Relationship,
30+
)
31+
32+
USERNAME = "neo4j"
33+
PASSWORD = "password"
34+
URI = "neo4j://localhost:7687"
35+
DATABASE = "neo4j"
36+
37+
EXPECTED_DOCUMENT_COUNT = 1
38+
39+
40+
# NOTE: Precheck tests are read-only so we utilize the same container for all tests.
41+
# If new tests require clean neo4j container, this fixture's scope should be adjusted.
42+
@pytest.fixture(autouse=True, scope="module")
43+
def _neo4j_server():
44+
with container_context(
45+
image="neo4j:latest", environment={"NEO4J_AUTH": "neo4j/password"}, ports={"7687": "7687"}
46+
):
47+
driver = GraphDatabase.driver(uri=URI, auth=(USERNAME, PASSWORD))
48+
wait_for_connection(driver)
49+
driver.close()
50+
yield
51+
52+
53+
@pytest.mark.asyncio
54+
@pytest.mark.tags(DESTINATION_TAG, CONNECTOR_TYPE)
55+
async def test_neo4j_destination(upload_file: Path, tmp_path: Path):
56+
stager = Neo4jUploadStager()
57+
uploader = Neo4jUploader(
58+
connection_config=Neo4jConnectionConfig(
59+
access_config=Neo4jAccessConfig(password=PASSWORD), # type: ignore
60+
username=USERNAME,
61+
uri=URI,
62+
database=DATABASE,
63+
),
64+
upload_config=Neo4jUploaderConfig(),
65+
)
66+
file_data = FileData(
67+
identifier="mock-file-data",
68+
connector_type="neo4j",
69+
source_identifiers=SourceIdentifiers(
70+
filename=upload_file.name,
71+
fullpath=upload_file.name,
72+
),
73+
metadata=FileDataSourceMetadata(
74+
date_created=str(datetime(2022, 1, 1).timestamp()),
75+
date_modified=str(datetime(2022, 1, 2).timestamp()),
76+
),
77+
)
78+
staged_filepath = stager.run(
79+
upload_file,
80+
file_data=file_data,
81+
output_dir=tmp_path,
82+
output_filename=upload_file.name,
83+
)
84+
85+
await uploader.run_async(staged_filepath, file_data)
86+
await validate_uploaded_graph(upload_file)
87+
88+
modified_upload_file = tmp_path / f"modified-{upload_file.name}"
89+
with open(upload_file) as file:
90+
elements = json.load(file)
91+
for element in elements:
92+
element["element_id"] = str(uuid.uuid4())
93+
94+
with open(modified_upload_file, "w") as file:
95+
json.dump(elements, file, indent=4)
96+
97+
staged_filepath = stager.run(
98+
modified_upload_file,
99+
file_data=file_data,
100+
output_dir=tmp_path,
101+
output_filename=modified_upload_file.name,
102+
)
103+
await uploader.run_async(staged_filepath, file_data)
104+
await validate_uploaded_graph(modified_upload_file)
105+
106+
107+
@pytest.mark.tags(DESTINATION_TAG, CONNECTOR_TYPE)
108+
class TestPrecheck:
109+
@pytest.fixture
110+
def configured_uploader(self) -> Neo4jUploader:
111+
return Neo4jUploader(
112+
connection_config=Neo4jConnectionConfig(
113+
access_config=Neo4jAccessConfig(password=PASSWORD), # type: ignore
114+
username=USERNAME,
115+
uri=URI,
116+
database=DATABASE,
117+
),
118+
upload_config=Neo4jUploaderConfig(),
119+
)
120+
121+
def test_succeeds(self, configured_uploader: Neo4jUploader):
122+
configured_uploader.precheck()
123+
124+
def test_fails_on_invalid_password(self, configured_uploader: Neo4jUploader):
125+
configured_uploader.connection_config.access_config.get_secret_value().password = (
126+
"invalid-password"
127+
)
128+
with pytest.raises(
129+
DestinationConnectionError,
130+
match="{code: Neo.ClientError.Security.Unauthorized}",
131+
):
132+
configured_uploader.precheck()
133+
134+
def test_fails_on_invalid_username(self, configured_uploader: Neo4jUploader):
135+
configured_uploader.connection_config.username = "invalid-username"
136+
with pytest.raises(
137+
DestinationConnectionError, match="{code: Neo.ClientError.Security.Unauthorized}"
138+
):
139+
configured_uploader.precheck()
140+
141+
@pytest.mark.parametrize(
142+
("uri", "expected_error_msg"),
143+
[
144+
("neo4j://localhst:7687", "Cannot resolve address"),
145+
("neo4j://localhost:7777", "Unable to retrieve routing information"),
146+
],
147+
)
148+
def test_fails_on_invalid_uri(
149+
self, configured_uploader: Neo4jUploader, uri: str, expected_error_msg: str
150+
):
151+
configured_uploader.connection_config.uri = uri
152+
with pytest.raises(DestinationConnectionError, match=expected_error_msg):
153+
configured_uploader.precheck()
154+
155+
def test_fails_on_invalid_database(self, configured_uploader: Neo4jUploader):
156+
configured_uploader.connection_config.database = "invalid-database"
157+
with pytest.raises(
158+
DestinationConnectionError, match="{code: Neo.ClientError.Database.DatabaseNotFound}"
159+
):
160+
configured_uploader.precheck()
161+
162+
163+
def wait_for_connection(driver: Driver, retries: int = 10, delay_seconds: int = 2):
164+
attempts = 0
165+
while attempts < retries:
166+
try:
167+
driver.verify_connectivity()
168+
return
169+
except ServiceUnavailable:
170+
time.sleep(delay_seconds)
171+
attempts += 1
172+
173+
pytest.fail("Failed to connect with Neo4j server.")
174+
175+
176+
async def validate_uploaded_graph(upload_file: Path):
177+
with open(upload_file) as file:
178+
elements = json.load(file)
179+
180+
for element in elements:
181+
if "orig_elements" in element["metadata"]:
182+
element["metadata"]["orig_elements"] = elements_from_base64_gzipped_json(
183+
element["metadata"]["orig_elements"]
184+
)
185+
else:
186+
element["metadata"]["orig_elements"] = []
187+
188+
expected_chunks_count = len(elements)
189+
expected_element_count = len(
190+
{
191+
origin_element["element_id"]
192+
for chunk in elements
193+
for origin_element in chunk["metadata"]["orig_elements"]
194+
}
195+
)
196+
expected_nodes_count = expected_chunks_count + expected_element_count + EXPECTED_DOCUMENT_COUNT
197+
198+
driver = AsyncGraphDatabase.driver(uri=URI, auth=(USERNAME, PASSWORD))
199+
try:
200+
nodes_count = len((await driver.execute_query("MATCH (n) RETURN n"))[0])
201+
chunk_nodes_count = len(
202+
(await driver.execute_query(f"MATCH (n: {Label.CHUNK}) RETURN n"))[0]
203+
)
204+
document_nodes_count = len(
205+
(await driver.execute_query(f"MATCH (n: {Label.DOCUMENT}) RETURN n"))[0]
206+
)
207+
element_nodes_count = len(
208+
(await driver.execute_query(f"MATCH (n: {Label.UNSTRUCTURED_ELEMENT}) RETURN n"))[0]
209+
)
210+
with check:
211+
assert nodes_count == expected_nodes_count
212+
with check:
213+
assert document_nodes_count == EXPECTED_DOCUMENT_COUNT
214+
with check:
215+
assert chunk_nodes_count == expected_chunks_count
216+
with check:
217+
assert element_nodes_count == expected_element_count
218+
219+
records, _, _ = await driver.execute_query(
220+
f"MATCH ()-[r:{Relationship.PART_OF_DOCUMENT}]->(:{Label.DOCUMENT}) RETURN r"
221+
)
222+
part_of_document_count = len(records)
223+
224+
records, _, _ = await driver.execute_query(
225+
f"MATCH (:{Label.CHUNK})-[r:{Relationship.NEXT_CHUNK}]->(:{Label.CHUNK}) RETURN r"
226+
)
227+
next_chunk_count = len(records)
228+
229+
if not check.any_failures():
230+
with check:
231+
assert part_of_document_count == expected_chunks_count + expected_element_count
232+
with check:
233+
assert next_chunk_count == expected_chunks_count - 1
234+
235+
finally:
236+
await driver.close()

unstructured_ingest/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.3.9-dev2" # pragma: no cover
1+
__version__ = "0.3.9-dev3" # pragma: no cover

unstructured_ingest/utils/chunking.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1+
import base64
12
import hashlib
3+
import json
4+
import zlib
25
from itertools import groupby
36

47

@@ -43,3 +46,11 @@ def assign_and_map_hash_ids(elements: list[dict]) -> list[dict]:
4346
e["metadata"]["parent_id"] = old_to_new_mapping[parent_id]
4447

4548
return elements
49+
50+
51+
def elements_from_base64_gzipped_json(raw_s: str) -> list[dict]:
52+
decoded_b64_bytes = base64.b64decode(raw_s)
53+
elements_json_bytes = zlib.decompress(decoded_b64_bytes)
54+
elements_json_str = elements_json_bytes.decode("utf-8")
55+
element_dicts = json.loads(elements_json_str)
56+
return element_dicts

0 commit comments

Comments
 (0)