Skip to content

Commit 037ab47

Browse files
committed
Add export_graph()
1 parent 282ad27 commit 037ab47

File tree

1 file changed

+30
-7
lines changed

1 file changed

+30
-7
lines changed

entity_graph/graph_extractor/entities_graph_extractor.py

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,9 @@ def __init__(
3939
self.local = False
4040
self.collection = collection
4141

42-
def load_table_from_file(self, source, filename, table_name, data_type):
42+
def load_table_from_file(
43+
self, source: str | dict, filename: str, table_name: str, data_type: str
44+
):
4345
self._add_file_node(filename)
4446
self._add_table(self._read_source(source), filename, table_name, data_type)
4547
return
@@ -96,7 +98,9 @@ def _add_table(self, json_data, file_name, table_name, data_type):
9698

9799
def _add_identifiers(self):
98100
# Create new identifiers
99-
for name, table_name, cols, fill_values in self.extraction_plan.get("identifiers", []):
101+
for name, table_name, cols, fill_values in self.extraction_plan.get(
102+
"identifiers", []
103+
):
100104
new_id(
101105
self.entities_graph_manager,
102106
name,
@@ -107,7 +111,9 @@ def _add_identifiers(self):
107111
)
108112

109113
# Link identifiers
110-
for name, table_name, cols, fill_values in self.extraction_plan.get("identifiers_links", []):
114+
for name, table_name, cols, fill_values in self.extraction_plan.get(
115+
"identifiers_links", []
116+
):
111117
link_id(
112118
self.entities_graph_manager,
113119
name,
@@ -118,7 +124,12 @@ def _add_identifiers(self):
118124
)
119125

120126
def _make_instances(self):
121-
for id_name, table_name, do_hierarchy, override_cols in self.extraction_plan.get("instances_creation", []):
127+
for (
128+
id_name,
129+
table_name,
130+
do_hierarchy,
131+
override_cols,
132+
) in self.extraction_plan.get("instances_creation", []):
122133
# Make sure to pass the collection parameter
123134
create_instances(
124135
self.entities_graph_manager,
@@ -144,7 +155,9 @@ def _enrichment_matching(self):
144155
collection=self.collection,
145156
)
146157

147-
for table_name, label1, label2, new_labels in self.extraction_plan.get("enrichments", []):
158+
for table_name, label1, label2, new_labels in self.extraction_plan.get(
159+
"enrichments", []
160+
):
148161
# Pass the collection parameter
149162
enrich_from_table_name(
150163
self.entities_graph_manager,
@@ -301,7 +314,7 @@ def _instances_context(self):
301314
collection=self.collection,
302315
)
303316

304-
def extract_entities_graph(self, stages, extraction_plan_config=None):
317+
def extract_entities_graph_by_stage(self, stages, extraction_plan_config=None):
305318
self._make_extraction_plan(extraction_plan_config)
306319
# Define a dictionary mapping stage names to their corresponding methods
307320
stages_methods = {
@@ -326,7 +339,9 @@ def extract_entities_graph(self, stages, extraction_plan_config=None):
326339
return self
327340

328341
# extract_entities_graph (formerly extract_entities_graph2) checks for completed steps
329-
def extract_entities_graph2(self, extraction_plan_config=None):
342+
def extract_entities_graph(
343+
self, extraction_plan_config: dict[str, list[tuple]] | None = None
344+
):
330345
"""
331346
Extract entities graph - second phase.
332347
@@ -364,3 +379,11 @@ def extract_entities_graph2(self, extraction_plan_config=None):
364379
self.stages_done.add("_instances_context")
365380

366381
return self
382+
383+
def export_graph(
    self, collection: str = "default"
) -> tuple[list[str], list[tuple[str, str]], dict[str, dict]]:
    """
    Export objects sub-graph to json.

    Delegates to ``entities_graph_manager.export_objects_graph``.

    :param collection: Name of the collection to export. Defaults to
        ``"default"``, the value that was previously hard-coded, so
        existing callers keep their behavior.
    :return: Tuple of nodes, edges, and metadata
    """
    # NOTE(review): the extraction stages pass ``collection=self.collection``;
    # a graph built in another collection must be exported by passing that
    # name here. Confirm whether the default should become ``self.collection``.
    return self.entities_graph_manager.export_objects_graph(collection=collection)

0 commit comments

Comments
 (0)