99from pathlib import Path
1010from typing import TYPE_CHECKING , Any , AsyncGenerator , Literal , Optional
1111
12- from pydantic import BaseModel , ConfigDict , Field , Secret , field_validator
12+ from pydantic import BaseModel , ConfigDict , Field , Secret , ValidationError , field_validator
1313
14+ from unstructured_ingest .data_types .entities import EntitiesData , Entity , EntityRelationship
1415from unstructured_ingest .data_types .file_data import FileData
1516from unstructured_ingest .error import DestinationConnectionError
1617from unstructured_ingest .interfaces import (
@@ -97,7 +98,6 @@ def run( # type: ignore
9798 ** kwargs : Any ,
9899 ) -> Path :
99100 elements = get_json_data (elements_filepath )
100-
101101 nx_graph = self ._create_lexical_graph (
102102 elements , self ._create_document_node (file_data = file_data )
103103 )
@@ -109,28 +109,54 @@ def run( # type: ignore
109109
110110 return output_filepath
111111
def _add_entities(self, entities: list[Entity], graph: "Graph", element_node: _Node) -> None:
    """Link already-validated entities to the element they were found in.

    For each entity two edges are added to ``graph``:

    * entity node -> entity-type node, tagged ``Relationship.ENTITY_TYPE``;
    * ``element_node`` -> entity node, tagged ``Relationship.HAS_ENTITY``.

    Both the entity node and its type node carry the ``Label.ENTITY`` label and
    use the entity text / type string as their id.
    """

    def entity_labelled_node(identifier):
        # A fresh labels list per node, exactly as the inline constructors did.
        return _Node(labels=[Label.ENTITY], properties={"id": identifier}, id_=identifier)

    for item in entities:
        entity_node = entity_labelled_node(item.entity)
        type_node = entity_labelled_node(item.type)
        graph.add_edge(entity_node, type_node, relationship=Relationship.ENTITY_TYPE)
        graph.add_edge(element_node, entity_node, relationship=Relationship.HAS_ENTITY)
133123
def _add_entity_relationships(
    self, relationships: list[EntityRelationship], graph: "Graph"
) -> None:
    """Add one entity-to-entity edge to ``graph`` per relationship.

    The edge runs from the node identified by ``from_`` to the node identified
    by ``to`` and is tagged with the relationship's own ``relationship`` value.
    Both endpoints are built as ``Label.ENTITY`` nodes whose id is the
    referenced identifier.
    """

    def entity_labelled_node(identifier):
        return _Node(labels=[Label.ENTITY], properties={"id": identifier}, id_=identifier)

    for link in relationships:
        source_node = entity_labelled_node(link.from_)
        target_node = entity_labelled_node(link.to)
        graph.add_edge(source_node, target_node, relationship=link.relationship)
137+
def _add_entity_data(self, element: dict, graph: "Graph", element_node: _Node) -> None:
    """Add the entity data stored under ``element["metadata"]["entities"]`` to ``graph``.

    Two shapes of entity data are supported:

    * a list of entity dicts: every dict is validated as an ``Entity`` and the
      valid ones are attached to ``element_node``; non-dict items and dicts that
      fail validation are skipped individually;
    * a dict: validated as a whole as ``EntitiesData``; its ``items`` are attached
      to ``element_node`` and its ``relationships`` are added as edges between
      entities.

    Missing, empty or otherwise shaped entity data is ignored. Validation
    problems are logged as warnings and never raised.

    Args:
        element: Element dictionary, expected to hold a ``metadata`` mapping.
        graph: Graph that receives the new nodes and edges.
        element_node: Node of the element the entities belong to.
    """
    # dict.get only falls back to its default when the key is absent, so a
    # "metadata" key holding null (or any non-mapping) is handled explicitly.
    metadata = element.get("metadata")
    if not isinstance(metadata, dict):
        return None
    raw_entities = metadata.get("entities")
    if not raw_entities:
        return None

    if isinstance(raw_entities, list):
        valid_entities = []
        skipped = 0
        for raw_entity in raw_entities:
            if not isinstance(raw_entity, dict):
                continue
            # Validate item by item so a single malformed entity does not
            # discard the well-formed entities of the same element.
            try:
                valid_entities.append(Entity.model_validate(raw_entity))
            except ValidationError:
                skipped += 1
        if skipped:
            logger.warning(
                "Failed to add entities to the graph. "
                "Please check the format of the entities in the input data. "
                "Skipped %d invalid entities.",
                skipped,
            )
        self._add_entities(valid_entities, graph, element_node)
    elif isinstance(raw_entities, dict):
        # Only validation can raise ValidationError; the graph is mutated
        # outside of the try block.
        try:
            entity_data = EntitiesData.model_validate(raw_entities)
        except ValidationError as error:
            logger.warning(
                "Failed to add entities to the graph. "
                "Please check the format of the entities in the input data. "
                "Details: %s",
                error,
            )
            return None
        self._add_entities(entity_data.items, graph, element_node)
        self._add_entity_relationships(entity_data.relationships, graph)
    return None
159+
134160 def _create_lexical_graph (self , elements : list [dict ], document_node : _Node ) -> "Graph" :
135161 import networkx as nx
136162
@@ -149,7 +175,7 @@ def _create_lexical_graph(self, elements: list[dict], document_node: _Node) -> "
149175 previous_node = element_node
150176 graph .add_edge (element_node , document_node , relationship = Relationship .PART_OF_DOCUMENT )
151177
152- self ._add_entities (element , graph , element_node )
178+ self ._add_entity_data (element , graph , element_node )
153179
154180 if self ._is_chunk (element ):
155181 for origin_element in format_and_truncate_orig_elements (element , include_text = True ):
@@ -165,7 +191,7 @@ def _create_lexical_graph(self, elements: list[dict], document_node: _Node) -> "
165191 document_node ,
166192 relationship = Relationship .PART_OF_DOCUMENT ,
167193 )
168- self ._add_entities (origin_element , graph , origin_element_node )
194+ self ._add_entity_data (origin_element , graph , origin_element_node )
169195
170196 return graph
171197
@@ -208,7 +234,9 @@ def from_nx(cls, nx_graph: "MultiDiGraph") -> _GraphData:
208234 _Edge (
209235 source = u ,
210236 destination = v ,
211- relationship = Relationship (data_dict ["relationship" ]),
237+ relationship = Relationship (data_dict ["relationship" ])
238+ if data_dict ["relationship" ] in Relationship
239+ else data_dict ["relationship" ],
212240 )
213241 for u , v , data_dict in nx_graph .edges (data = True )
214242 ]
@@ -242,7 +270,7 @@ class _Edge(BaseModel):
242270
243271 source : _Node
244272 destination : _Node
245- relationship : Relationship
273+ relationship : Relationship | str
246274
247275
248276class Label (Enum ):
@@ -380,7 +408,7 @@ async def _merge_graph(self, graph_data: _GraphData, client: AsyncDriver) -> Non
380408 )
381409 logger .info (f"Finished merging { len (graph_data .nodes )} graph nodes." )
382410
383- edges_by_relationship : defaultdict [tuple [Relationship , Label , Label ], list [_Edge ]] = (
411+ edges_by_relationship : defaultdict [tuple [Relationship | str , Label , Label ], list [_Edge ]] = (
384412 defaultdict (list )
385413 )
386414 for edge in graph_data .edges :
@@ -463,16 +491,19 @@ def _create_nodes_query(nodes: list[_Node], label: Label) -> tuple[str, dict]:
463491 @staticmethod
464492 def _create_edges_query (
465493 edges : list [_Edge ],
466- relationship : Relationship ,
494+ relationship : Relationship | str ,
467495 source_label : Label ,
468496 destination_label : Label ,
469497 ) -> tuple [str , dict ]:
470498 logger .info (f"Preparing MERGE query for { len (edges )} { relationship } relationships." )
499+ relationship = (
500+ relationship .value if isinstance (relationship , Relationship ) else relationship
501+ )
471502 query_string = f"""
472503 UNWIND $edges AS edge
473504 MATCH (u: `{ source_label .value } ` {{id: edge.source}})
474505 MATCH (v: `{ destination_label .value } ` {{id: edge.destination}})
475- MERGE (u)-[:`{ relationship . value } `]->(v)
506+ MERGE (u)-[:`{ relationship } `]->(v)
476507 """
477508 parameters = {
478509 "edges" : [
0 commit comments