from dataclasses import dataclass
from enum import Enum
from pathlib import Path
-from typing import TYPE_CHECKING, Any, AsyncGenerator, Optional
+from typing import TYPE_CHECKING, Any, AsyncGenerator, Literal, Optional

-from pydantic import BaseModel, ConfigDict, Field, Secret
+from pydantic import BaseModel, ConfigDict, Field, Secret, field_validator

from unstructured_ingest.error import DestinationConnectionError
from unstructured_ingest.logger import logger
    DestinationRegistryEntry,
)

+SimilarityFunction = Literal["cosine"]
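+# Only cosine similarity is currently accepted for the vector index created on Chunk nodes.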
+
if TYPE_CHECKING:
    from neo4j import AsyncDriver, Auth
    from networkx import Graph, MultiDiGraph
@@ -44,9 +46,9 @@ class Neo4jAccessConfig(AccessConfig):
class Neo4jConnectionConfig(ConnectionConfig):
    access_config: Secret[Neo4jAccessConfig]
    connector_type: str = Field(default=CONNECTOR_TYPE, init=False)
-    username: str
+    username: str = Field(default="neo4j")
    uri: str = Field(description="Neo4j Connection URI <scheme>://<host>:<port>")
-    database: str = Field(description="Name of the target database")
+    database: str = Field(default="neo4j", description="Name of the target database")

    @requires_dependencies(["neo4j"], extras="neo4j")
    @asynccontextmanager
@@ -186,8 +188,8 @@ def from_nx(cls, nx_graph: "MultiDiGraph") -> _GraphData:
        nodes = list(nx_graph.nodes())
        edges = [
            _Edge(
-                source_id=u.id_,
-                destination_id=v.id_,
+                source=u,
+                destination=v,
                relationship=Relationship(data_dict["relationship"]),
            )
            for u, v, data_dict in nx_graph.edges(data=True)
@@ -198,19 +200,30 @@ def from_nx(cls, nx_graph: "MultiDiGraph") -> _GraphData:
class _Node(BaseModel):
    model_config = ConfigDict()

-    id_: str = Field(default_factory=lambda: str(uuid.uuid4()))
-    labels: list[Label] = Field(default_factory=list)
+    labels: list[Label]
    properties: dict = Field(default_factory=dict)
+    id_: str = Field(default_factory=lambda: str(uuid.uuid4()))

    def __hash__(self):
        return hash(self.id_)

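+    # The first label acts as the node's primary label; it is used below to group
+    # nodes into batches and to build the typed MATCH/MERGE patterns.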
+    @property
+    def main_label(self) -> Label:
+        return self.labels[0]
+
+    @field_validator("labels", mode="after")
+    @classmethod
+    def require_at_least_one_label(cls, value: list[Label]) -> list[Label]:
+        if not value:
+            raise ValueError("Node must have at least one label.")
+        return value
+

class _Edge(BaseModel):
    model_config = ConfigDict()

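+    # Edges now reference full node objects (rather than bare ids) so that the
+    # source and destination labels are available when building the edge queries.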
-    source_id: str
-    destination_id: str
+    source: _Node
+    destination: _Node
    relationship: Relationship


@@ -229,7 +242,14 @@ class Relationship(Enum):

class Neo4jUploaderConfig(UploaderConfig):
    batch_size: int = Field(
-        default=100, description="Maximal number of nodes/relationships created per transaction."
+        default=1000, description="Maximal number of nodes/relationships created per transaction."
+    )
+    similarity_function: SimilarityFunction = Field(
+        default="cosine",
+        description="Vector similarity function used to create index on Chunk nodes",
+    )
+    create_destination: bool = Field(
+        default=True, description="Create destination if it does not exist"
    )


@@ -257,6 +277,13 @@ async def run_async(self, path: Path, file_data: FileData, **kwargs) -> None:  #
        graph_data = _GraphData.model_validate(staged_data)
        async with self.connection_config.get_client() as client:
            await self._create_uniqueness_constraints(client)
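+            # Infer the embedding size from the staged chunk nodes; the vector index is
+            # only created when embeddings are present and create_destination is enabled.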
+            embedding_dimensions = self._get_embedding_dimensions(graph_data)
+            if embedding_dimensions and self.upload_config.create_destination:
+                await self._create_vector_index(
+                    client,
+                    dimensions=embedding_dimensions,
+                    similarity_function=self.upload_config.similarity_function,
+                )
            await self._delete_old_data_if_exists(file_data, client=client)
            await self._merge_graph(graph_data=graph_data, client=client)

@@ -274,13 +301,33 @@ async def _create_uniqueness_constraints(self, client: AsyncDriver) -> None:
                """
            )

+    async def _create_vector_index(
+        self, client: AsyncDriver, dimensions: int, similarity_function: SimilarityFunction
+    ) -> None:
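+        # Create a vector index over the 'embedding' property of Chunk nodes, if one
+        # does not exist yet, using the configured similarity function and dimensions.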
+        label = Label.CHUNK
+        logger.info(
+            f"Creating index on nodes labeled '{label.value}' if it does not already exist."
+        )
+        index_name = f"{label.value.lower()}_vector"
+        await client.execute_query(
+            f"""
+            CREATE VECTOR INDEX {index_name} IF NOT EXISTS
+            FOR (n:{label.value}) ON n.embedding
+            OPTIONS {{indexConfig: {{
+                `vector.similarity_function`: '{similarity_function}',
+                `vector.dimensions`: {dimensions}}}
+            }}
+            """
+        )
+
    async def _delete_old_data_if_exists(self, file_data: FileData, client: AsyncDriver) -> None:
        logger.info(f"Deleting old data for the record '{file_data.identifier}' (if present).")
        _, summary, _ = await client.execute_query(
            f"""
-            MATCH (n: {Label.DOCUMENT.value} {{id: $identifier}})
-            MATCH (n)--(m: {Label.CHUNK.value}|{Label.UNSTRUCTURED_ELEMENT.value})
-            DETACH DELETE m""",
+            MATCH (n: `{Label.DOCUMENT.value}` {{id: $identifier}})
+            MATCH (n)--(m: `{Label.CHUNK.value}`|`{Label.UNSTRUCTURED_ELEMENT.value}`)
+            DETACH DELETE m
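+            // the document node itself is now removed as well, not just its chunks/elements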
+            DETACH DELETE n""",
            identifier=file_data.identifier,
        )
        logger.info(
@@ -289,33 +336,39 @@ async def _delete_old_data_if_exists(self, file_data: FileData, client: AsyncDriver) -> None:
        )

    async def _merge_graph(self, graph_data: _GraphData, client: AsyncDriver) -> None:
-        nodes_by_labels: defaultdict[tuple[Label, ...], list[_Node]] = defaultdict(list)
+        nodes_by_labels: defaultdict[Label, list[_Node]] = defaultdict(list)
        for node in graph_data.nodes:
-            nodes_by_labels[tuple(node.labels)].append(node)
-
+            nodes_by_labels[node.main_label].append(node)
        logger.info(f"Merging {len(graph_data.nodes)} graph nodes.")
        # NOTE: Processed in parallel as there's no overlap between accessed nodes
        await self._execute_queries(
            [
-                self._create_nodes_query(nodes_batch, labels)
-                for labels, nodes in nodes_by_labels.items()
+                self._create_nodes_query(nodes_batch, label)
+                for label, nodes in nodes_by_labels.items()
                for nodes_batch in batch_generator(nodes, batch_size=self.upload_config.batch_size)
            ],
            client=client,
            in_parallel=True,
        )
        logger.info(f"Finished merging {len(graph_data.nodes)} graph nodes.")

-        edges_by_relationship: defaultdict[Relationship, list[_Edge]] = defaultdict(list)
+        edges_by_relationship: defaultdict[tuple[Relationship, Label, Label], list[_Edge]] = (
+            defaultdict(list)
+        )
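+        # Group edges by (relationship, source label, destination label) so that each
+        # batch can be merged with a single fully typed MATCH/MERGE query.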
        for edge in graph_data.edges:
-            edges_by_relationship[edge.relationship].append(edge)
+            key = (edge.relationship, edge.source.main_label, edge.destination.main_label)
+            edges_by_relationship[key].append(edge)

        logger.info(f"Merging {len(graph_data.edges)} graph relationships (edges).")
        # NOTE: Processed sequentially to avoid queries locking node access to one another
        await self._execute_queries(
            [
-                self._create_edges_query(edges_batch, relationship)
-                for relationship, edges in edges_by_relationship.items()
+                self._create_edges_query(edges_batch, relationship, source_label, destination_label)
+                for (
+                    relationship,
+                    source_label,
+                    destination_label,
+                ), edges in edges_by_relationship.items()
                for edges_batch in batch_generator(edges, batch_size=self.upload_config.batch_size)
            ],
            client=client,
@@ -328,53 +381,86 @@ async def _execute_queries(
        client: AsyncDriver,
        in_parallel: bool = False,
    ) -> None:
+        from neo4j import EagerResult
+
+        results: list[EagerResult] = []
+        logger.info(
+            f"Executing {len(queries_with_parameters)} "
+            + f"{'parallel' if in_parallel else 'sequential'} Cypher statements."
+        )
        if in_parallel:
-            logger.info(f"Executing {len(queries_with_parameters)} queries in parallel.")
-            await asyncio.gather(
+            results = await asyncio.gather(
                *[
                    client.execute_query(query, parameters_=parameters)
                    for query, parameters in queries_with_parameters
                ]
            )
-            logger.info("Finished executing parallel queries.")
        else:
-            logger.info(f"Executing {len(queries_with_parameters)} queries sequentially.")
            for i, (query, parameters) in enumerate(queries_with_parameters):
-                logger.info(f"Query #{i} started.")
-                await client.execute_query(query, parameters_=parameters)
-                logger.info(f"Query #{i} finished.")
-            logger.info(
-                f"Finished executing all ({len(queries_with_parameters)}) sequential queries."
-            )
+                logger.info(f"Statement #{i} started.")
+                results.append(await client.execute_query(query, parameters_=parameters))
+                logger.info(f"Statement #{i} finished.")
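+        # Aggregate the creation counters reported in each result summary for logging.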
+        node_count = sum(res.summary.counters.nodes_created for res in results)
+        rel_count = sum(res.summary.counters.relationships_created for res in results)
+        logger.info(
+            f"Finished executing all ({len(queries_with_parameters)}) "
+            + f"{'parallel' if in_parallel else 'sequential'} Cypher statements. "
+            + f"Created {node_count} nodes, {rel_count} relationships."
+        )

    @staticmethod
-    def _create_nodes_query(nodes: list[_Node], labels: tuple[Label, ...]) -> tuple[str, dict]:
-        labels_string = ", ".join([label.value for label in labels])
-        logger.info(f"Preparing MERGE query for {len(nodes)} nodes labeled '{labels_string}'.")
+    def _create_nodes_query(nodes: list[_Node], label: Label) -> tuple[str, dict]:
+        logger.info(f"Preparing MERGE query for {len(nodes)} nodes labeled '{label}'.")
        query_string = f"""
            UNWIND $nodes AS node
-            MERGE (n: {labels_string} {{id: node.id}})
+            MERGE (n: `{label.value}` {{id: node.id}})
            SET n += node.properties
+            SET n:$(node.labels)
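+            // remaining (non-primary) labels are applied dynamically here
+            // (dynamic label syntax requires a recent Neo4j server version)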
+            WITH * WHERE node.vector IS NOT NULL
+            CALL db.create.setNodeVectorProperty(n, 'embedding', node.vector)
            """
-        parameters = {"nodes": [{"id": node.id_, "properties": node.properties} for node in nodes]}
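+        # The embedding is split out of the generic properties and passed separately as
+        # 'vector' so it can be stored via db.create.setNodeVectorProperty above.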
+        parameters = {
+            "nodes": [
+                {
+                    "id": node.id_,
+                    "labels": [l.value for l in node.labels if l != label],  # noqa: E741
+                    "vector": node.properties.pop("embeddings", None),
+                    "properties": node.properties,
+                }
+                for node in nodes
+            ]
+        }
        return query_string, parameters

    @staticmethod
-    def _create_edges_query(edges: list[_Edge], relationship: Relationship) -> tuple[str, dict]:
+    def _create_edges_query(
+        edges: list[_Edge],
+        relationship: Relationship,
+        source_label: Label,
+        destination_label: Label,
+    ) -> tuple[str, dict]:
        logger.info(f"Preparing MERGE query for {len(edges)} {relationship} relationships.")
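+        # Typed MATCHes narrow the lookup to the relevant label (and its id constraint,
+        # where one exists) instead of scanning every node for the given id.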
        query_string = f"""
            UNWIND $edges AS edge
-            MATCH (u {{id: edge.source}})
-            MATCH (v {{id: edge.destination}})
-            MERGE (u)-[:{relationship.value}]->(v)
+            MATCH (u: `{source_label.value}` {{id: edge.source}})
+            MATCH (v: `{destination_label.value}` {{id: edge.destination}})
+            MERGE (u)-[:`{relationship.value}`]->(v)
            """
        parameters = {
            "edges": [
-                {"source": edge.source_id, "destination": edge.destination_id} for edge in edges
+                {"source": edge.source.id_, "destination": edge.destination.id_} for edge in edges
            ]
        }
        return query_string, parameters

+    def _get_embedding_dimensions(self, graph_data: _GraphData) -> int | None:
+        """Embedding dimensions inferred from chunk nodes or None if it can't be determined."""
+        for node in graph_data.nodes:
+            if Label.CHUNK in node.labels and "embeddings" in node.properties:
+                return len(node.properties["embeddings"])
+
+        return None
+

neo4j_destination_entry = DestinationRegistryEntry(
    connection_config=Neo4jConnectionConfig,