Skip to content

Commit 4edcd1b

Browse files
jexp, Filip Knefel, ahmetmeleq
authored
Proposal for Neo4j Uploader Improvements (#357)
* proposed changes to improve neo4j operations - use neo4j-rust-ext instead of plain neo4j driver for 10x perf improvement - always match on a single node label (equivalent to the constraint), never blank matches - group relationships by type, source- and target-type - increase batch size - use vector property procedure to set fp32 instead of fp64 - method to select the main label for a node - TODO: create vector index would need information from the embedder (dimension) and similarity function (from config) - Set extra labels * Fixed one missing label.value, added created nodes/rels to log output * Set default values for username and database * Modify main label logic, deduplicate data from node and edge Implement main label getting logic on the Node. Validate that Node has at least one label to ensure there's always a main label. Remove data about nodes from the edge, refer to Node objects directly in the Edge. * Add index creation with 'cosine' similarity function * make creation optional, resolve conflict * changelog and version * tidy --------- Co-authored-by: Filip Knefel <[email protected]> Co-authored-by: Ahmet Melek <[email protected]> Co-authored-by: Ahmet Melek <[email protected]>
1 parent 442dbfa commit 4edcd1b

File tree

5 files changed

+135
-48
lines changed

5 files changed

+135
-48
lines changed

CHANGELOG.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
## 0.5.3-dev1
1+
## 0.5.3-dev2
22

33
### Enhancements
44

5+
* **Improvements on Neo4J uploader, and ability to create a vector index**
56
* **Optimize embedder code** - Move duplicate code to base interface, exit early if no elements have text.
67

78
### Fixes

requirements/connectors/neo4j.in

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
neo4j
1+
neo4j-rust-ext
22
cymple
33
networkx

requirements/connectors/neo4j.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22
# uv pip compile ./connectors/neo4j.in --output-file ./connectors/neo4j.txt --no-strip-extras --python-version 3.9
33
cymple==0.12.0
44
# via -r ./connectors/neo4j.in
5-
neo4j==5.28.1
5+
neo4j-rust-ext==5.27.0.0
66
# via -r ./connectors/neo4j.in
77
networkx==3.2.1
88
# via -r ./connectors/neo4j.in
9-
pytz==2025.1
9+
pytz==2024.2
1010
# via neo4j

unstructured_ingest/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.5.3-dev1" # pragma: no cover
1+
__version__ = "0.5.3-dev2" # pragma: no cover

unstructured_ingest/v2/processes/connectors/neo4j.py

Lines changed: 129 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88
from dataclasses import dataclass
99
from enum import Enum
1010
from pathlib import Path
11-
from typing import TYPE_CHECKING, Any, AsyncGenerator, Optional
11+
from typing import TYPE_CHECKING, Any, AsyncGenerator, Literal, Optional
1212

13-
from pydantic import BaseModel, ConfigDict, Field, Secret
13+
from pydantic import BaseModel, ConfigDict, Field, Secret, field_validator
1414

1515
from unstructured_ingest.error import DestinationConnectionError
1616
from unstructured_ingest.logger import logger
@@ -30,6 +30,8 @@
3030
DestinationRegistryEntry,
3131
)
3232

33+
SimilarityFunction = Literal["cosine"]
34+
3335
if TYPE_CHECKING:
3436
from neo4j import AsyncDriver, Auth
3537
from networkx import Graph, MultiDiGraph
@@ -44,9 +46,9 @@ class Neo4jAccessConfig(AccessConfig):
4446
class Neo4jConnectionConfig(ConnectionConfig):
4547
access_config: Secret[Neo4jAccessConfig]
4648
connector_type: str = Field(default=CONNECTOR_TYPE, init=False)
47-
username: str
49+
username: str = Field(default="neo4j")
4850
uri: str = Field(description="Neo4j Connection URI <scheme>://<host>:<port>")
49-
database: str = Field(description="Name of the target database")
51+
database: str = Field(default="neo4j", description="Name of the target database")
5052

5153
@requires_dependencies(["neo4j"], extras="neo4j")
5254
@asynccontextmanager
@@ -186,8 +188,8 @@ def from_nx(cls, nx_graph: "MultiDiGraph") -> _GraphData:
186188
nodes = list(nx_graph.nodes())
187189
edges = [
188190
_Edge(
189-
source_id=u.id_,
190-
destination_id=v.id_,
191+
source=u,
192+
destination=v,
191193
relationship=Relationship(data_dict["relationship"]),
192194
)
193195
for u, v, data_dict in nx_graph.edges(data=True)
@@ -198,19 +200,30 @@ def from_nx(cls, nx_graph: "MultiDiGraph") -> _GraphData:
198200
class _Node(BaseModel):
199201
model_config = ConfigDict()
200202

201-
id_: str = Field(default_factory=lambda: str(uuid.uuid4()))
202-
labels: list[Label] = Field(default_factory=list)
203+
labels: list[Label]
203204
properties: dict = Field(default_factory=dict)
205+
id_: str = Field(default_factory=lambda: str(uuid.uuid4()))
204206

205207
def __hash__(self):
206208
return hash(self.id_)
207209

210+
@property
211+
def main_label(self) -> Label:
212+
return self.labels[0]
213+
214+
@classmethod
215+
@field_validator("labels", mode="after")
216+
def require_at_least_one_label(cls, value: list[Label]) -> list[Label]:
217+
if not value:
218+
raise ValueError("Node must have at least one label.")
219+
return value
220+
208221

209222
class _Edge(BaseModel):
210223
model_config = ConfigDict()
211224

212-
source_id: str
213-
destination_id: str
225+
source: _Node
226+
destination: _Node
214227
relationship: Relationship
215228

216229

@@ -229,7 +242,14 @@ class Relationship(Enum):
229242

230243
class Neo4jUploaderConfig(UploaderConfig):
231244
batch_size: int = Field(
232-
default=100, description="Maximal number of nodes/relationships created per transaction."
245+
default=1000, description="Maximal number of nodes/relationships created per transaction."
246+
)
247+
similarity_function: SimilarityFunction = Field(
248+
default="cosine",
249+
description="Vector similarity function used to create index on Chunk nodes",
250+
)
251+
create_destination: bool = Field(
252+
default=True, description="Create destination if it does not exist"
233253
)
234254

235255

@@ -257,6 +277,13 @@ async def run_async(self, path: Path, file_data: FileData, **kwargs) -> None: #
257277
graph_data = _GraphData.model_validate(staged_data)
258278
async with self.connection_config.get_client() as client:
259279
await self._create_uniqueness_constraints(client)
280+
embedding_dimensions = self._get_embedding_dimensions(graph_data)
281+
if embedding_dimensions and self.upload_config.create_destination:
282+
await self._create_vector_index(
283+
client,
284+
dimensions=embedding_dimensions,
285+
similarity_function=self.upload_config.similarity_function,
286+
)
260287
await self._delete_old_data_if_exists(file_data, client=client)
261288
await self._merge_graph(graph_data=graph_data, client=client)
262289

@@ -274,13 +301,33 @@ async def _create_uniqueness_constraints(self, client: AsyncDriver) -> None:
274301
"""
275302
)
276303

304+
async def _create_vector_index(
305+
self, client: AsyncDriver, dimensions: int, similarity_function: SimilarityFunction
306+
) -> None:
307+
label = Label.CHUNK
308+
logger.info(
309+
f"Creating index on nodes labeled '{label.value}' if it does not already exist."
310+
)
311+
index_name = f"{label.value.lower()}_vector"
312+
await client.execute_query(
313+
f"""
314+
CREATE VECTOR INDEX {index_name} IF NOT EXISTS
315+
FOR (n:{label.value}) ON n.embedding
316+
OPTIONS {{indexConfig: {{
317+
`vector.similarity_function`: '{similarity_function}',
318+
`vector.dimensions`: {dimensions}}}
319+
}}
320+
"""
321+
)
322+
277323
async def _delete_old_data_if_exists(self, file_data: FileData, client: AsyncDriver) -> None:
278324
logger.info(f"Deleting old data for the record '{file_data.identifier}' (if present).")
279325
_, summary, _ = await client.execute_query(
280326
f"""
281-
MATCH (n: {Label.DOCUMENT.value} {{id: $identifier}})
282-
MATCH (n)--(m: {Label.CHUNK.value}|{Label.UNSTRUCTURED_ELEMENT.value})
283-
DETACH DELETE m""",
327+
MATCH (n: `{Label.DOCUMENT.value}` {{id: $identifier}})
328+
MATCH (n)--(m: `{Label.CHUNK.value}`|`{Label.UNSTRUCTURED_ELEMENT.value}`)
329+
DETACH DELETE m
330+
DETACH DELETE n""",
284331
identifier=file_data.identifier,
285332
)
286333
logger.info(
@@ -289,33 +336,39 @@ async def _delete_old_data_if_exists(self, file_data: FileData, client: AsyncDri
289336
)
290337

291338
async def _merge_graph(self, graph_data: _GraphData, client: AsyncDriver) -> None:
292-
nodes_by_labels: defaultdict[tuple[Label, ...], list[_Node]] = defaultdict(list)
339+
nodes_by_labels: defaultdict[Label, list[_Node]] = defaultdict(list)
293340
for node in graph_data.nodes:
294-
nodes_by_labels[tuple(node.labels)].append(node)
295-
341+
nodes_by_labels[node.main_label].append(node)
296342
logger.info(f"Merging {len(graph_data.nodes)} graph nodes.")
297343
# NOTE: Processed in parallel as there's no overlap between accessed nodes
298344
await self._execute_queries(
299345
[
300-
self._create_nodes_query(nodes_batch, labels)
301-
for labels, nodes in nodes_by_labels.items()
346+
self._create_nodes_query(nodes_batch, label)
347+
for label, nodes in nodes_by_labels.items()
302348
for nodes_batch in batch_generator(nodes, batch_size=self.upload_config.batch_size)
303349
],
304350
client=client,
305351
in_parallel=True,
306352
)
307353
logger.info(f"Finished merging {len(graph_data.nodes)} graph nodes.")
308354

309-
edges_by_relationship: defaultdict[Relationship, list[_Edge]] = defaultdict(list)
355+
edges_by_relationship: defaultdict[tuple[Relationship, Label, Label], list[_Edge]] = (
356+
defaultdict(list)
357+
)
310358
for edge in graph_data.edges:
311-
edges_by_relationship[edge.relationship].append(edge)
359+
key = (edge.relationship, edge.source.main_label, edge.destination.main_label)
360+
edges_by_relationship[key].append(edge)
312361

313362
logger.info(f"Merging {len(graph_data.edges)} graph relationships (edges).")
314363
# NOTE: Processed sequentially to avoid queries locking node access to one another
315364
await self._execute_queries(
316365
[
317-
self._create_edges_query(edges_batch, relationship)
318-
for relationship, edges in edges_by_relationship.items()
366+
self._create_edges_query(edges_batch, relationship, source_label, destination_label)
367+
for (
368+
relationship,
369+
source_label,
370+
destination_label,
371+
), edges in edges_by_relationship.items()
319372
for edges_batch in batch_generator(edges, batch_size=self.upload_config.batch_size)
320373
],
321374
client=client,
@@ -328,53 +381,86 @@ async def _execute_queries(
328381
client: AsyncDriver,
329382
in_parallel: bool = False,
330383
) -> None:
384+
from neo4j import EagerResult
385+
386+
results: list[EagerResult] = []
387+
logger.info(
388+
f"Executing {len(queries_with_parameters)} "
389+
+ f"{'parallel' if in_parallel else 'sequential'} Cypher statements."
390+
)
331391
if in_parallel:
332-
logger.info(f"Executing {len(queries_with_parameters)} queries in parallel.")
333-
await asyncio.gather(
392+
results = await asyncio.gather(
334393
*[
335394
client.execute_query(query, parameters_=parameters)
336395
for query, parameters in queries_with_parameters
337396
]
338397
)
339-
logger.info("Finished executing parallel queries.")
340398
else:
341-
logger.info(f"Executing {len(queries_with_parameters)} queries sequentially.")
342399
for i, (query, parameters) in enumerate(queries_with_parameters):
343-
logger.info(f"Query #{i} started.")
344-
await client.execute_query(query, parameters_=parameters)
345-
logger.info(f"Query #{i} finished.")
346-
logger.info(
347-
f"Finished executing all ({len(queries_with_parameters)}) sequential queries."
348-
)
400+
logger.info(f"Statement #{i} started.")
401+
results.append(await client.execute_query(query, parameters_=parameters))
402+
logger.info(f"Statement #{i} finished.")
403+
nodeCount = sum([res.summary.counters.nodes_created for res in results])
404+
relCount = sum([res.summary.counters.relationships_created for res in results])
405+
logger.info(
406+
f"Finished executing all ({len(queries_with_parameters)}) "
407+
+ f"{'parallel' if in_parallel else 'sequential'} Cypher statements. "
408+
+ f"Created {nodeCount} nodes, {relCount} relationships."
409+
)
349410

350411
@staticmethod
351-
def _create_nodes_query(nodes: list[_Node], labels: tuple[Label, ...]) -> tuple[str, dict]:
352-
labels_string = ", ".join([label.value for label in labels])
353-
logger.info(f"Preparing MERGE query for {len(nodes)} nodes labeled '{labels_string}'.")
412+
def _create_nodes_query(nodes: list[_Node], label: Label) -> tuple[str, dict]:
413+
logger.info(f"Preparing MERGE query for {len(nodes)} nodes labeled '{label}'.")
354414
query_string = f"""
355415
UNWIND $nodes AS node
356-
MERGE (n: {labels_string} {{id: node.id}})
416+
MERGE (n: `{label.value}` {{id: node.id}})
357417
SET n += node.properties
418+
SET n:$(node.labels)
419+
WITH * WHERE node.vector IS NOT NULL
420+
CALL db.create.setNodeVectorProperty(n, 'embedding', node.vector)
358421
"""
359-
parameters = {"nodes": [{"id": node.id_, "properties": node.properties} for node in nodes]}
422+
parameters = {
423+
"nodes": [
424+
{
425+
"id": node.id_,
426+
"labels": [l.value for l in node.labels if l != label], # noqa: E741
427+
"vector": node.properties.pop("embedding", None),
428+
"properties": node.properties,
429+
}
430+
for node in nodes
431+
]
432+
}
360433
return query_string, parameters
361434

362435
@staticmethod
363-
def _create_edges_query(edges: list[_Edge], relationship: Relationship) -> tuple[str, dict]:
436+
def _create_edges_query(
437+
edges: list[_Edge],
438+
relationship: Relationship,
439+
source_label: Label,
440+
destination_label: Label,
441+
) -> tuple[str, dict]:
364442
logger.info(f"Preparing MERGE query for {len(edges)} {relationship} relationships.")
365443
query_string = f"""
366444
UNWIND $edges AS edge
367-
MATCH (u {{id: edge.source}})
368-
MATCH (v {{id: edge.destination}})
369-
MERGE (u)-[:{relationship.value}]->(v)
445+
MATCH (u: `{source_label.value}` {{id: edge.source}})
446+
MATCH (v: `{destination_label.value}` {{id: edge.destination}})
447+
MERGE (u)-[:`{relationship.value}`]->(v)
370448
"""
371449
parameters = {
372450
"edges": [
373-
{"source": edge.source_id, "destination": edge.destination_id} for edge in edges
451+
{"source": edge.source.id_, "destination": edge.destination.id_} for edge in edges
374452
]
375453
}
376454
return query_string, parameters
377455

456+
def _get_embedding_dimensions(self, graph_data: _GraphData) -> int | None:
457+
"""Embedding dimensions inferred from chunk nodes or None if it can't be determined."""
458+
for node in graph_data.nodes:
459+
if Label.CHUNK in node.labels and "embeddings" in node.properties:
460+
return len(node.properties["embeddings"])
461+
462+
return None
463+
378464

379465
neo4j_destination_entry = DestinationRegistryEntry(
380466
connection_config=Neo4jConnectionConfig,

0 commit comments

Comments (0)