@@ -6,41 +6,33 @@
 from typing import Any
 from uuid import uuid4
 
-import networkx as nx
 import pandas as pd
 from datashaper import (
     AsyncType,
     VerbCallbacks,
 )
 
 from graphrag.cache.pipeline_cache import PipelineCache
-from graphrag.index.operations.create_graph import create_graph
 from graphrag.index.operations.extract_entities import extract_entities
-from graphrag.index.operations.snapshot import snapshot
-from graphrag.index.operations.snapshot_graphml import snapshot_graphml
 from graphrag.index.operations.summarize_descriptions import (
     summarize_descriptions,
 )
-from graphrag.storage.pipeline_storage import PipelineStorage
 
 
 async def extract_graph(
     text_units: pd.DataFrame,
     callbacks: VerbCallbacks,
     cache: PipelineCache,
-    storage: PipelineStorage,
     extraction_strategy: dict[str, Any] | None = None,
     extraction_num_threads: int = 4,
     extraction_async_mode: AsyncType = AsyncType.AsyncIO,
     entity_types: list[str] | None = None,
     summarization_strategy: dict[str, Any] | None = None,
     summarization_num_threads: int = 4,
-    snapshot_graphml_enabled: bool = False,
-    snapshot_transient_enabled: bool = False,
 ) -> tuple[pd.DataFrame, pd.DataFrame]:
     """All the steps to create the base entity graph."""
     # this returns a graph for each text unit, to be merged later
-    entity_dfs, relationship_dfs = await extract_entities(
+    entities, relationships = await extract_entities(
         text_units,
         callbacks,
         cache,
@@ -52,87 +44,38 @@ async def extract_graph(
         num_threads=extraction_num_threads,
     )
 
-    if not _validate_data(entity_dfs):
+    if not _validate_data(entities):
         error_msg = "Entity Extraction failed. No entities detected during extraction."
         callbacks.error(error_msg)
         raise ValueError(error_msg)
 
-    if not _validate_data(relationship_dfs):
+    if not _validate_data(relationships):
         error_msg = (
             "Entity Extraction failed. No relationships detected during extraction."
         )
         callbacks.error(error_msg)
         raise ValueError(error_msg)
 
-    merged_entities = _merge_entities(entity_dfs)
-    merged_relationships = _merge_relationships(relationship_dfs)
-
     entity_summaries, relationship_summaries = await summarize_descriptions(
-        merged_entities,
-        merged_relationships,
+        entities,
+        relationships,
         callbacks,
         cache,
         strategy=summarization_strategy,
         num_threads=summarization_num_threads,
     )
 
-    base_relationship_edges = _prep_edges(merged_relationships, relationship_summaries)
-
-    graph = create_graph(base_relationship_edges)
-
-    base_entity_nodes = _prep_nodes(merged_entities, entity_summaries, graph)
-
-    if snapshot_graphml_enabled:
-        # todo: extract graphs at each level, and add in meta like descriptions
-        await snapshot_graphml(
-            graph,
-            name="graph",
-            storage=storage,
-        )
+    base_relationship_edges = _prep_edges(relationships, relationship_summaries)
 
-    if snapshot_transient_enabled:
-        await snapshot(
-            base_entity_nodes,
-            name="base_entity_nodes",
-            storage=storage,
-            formats=["parquet"],
-        )
-        await snapshot(
-            base_relationship_edges,
-            name="base_relationship_edges",
-            storage=storage,
-            formats=["parquet"],
-        )
+    base_entity_nodes = _prep_nodes(entities, entity_summaries)
 
     return (base_entity_nodes, base_relationship_edges)
 
 
-def _merge_entities(entity_dfs) -> pd.DataFrame:
-    all_entities = pd.concat(entity_dfs, ignore_index=True)
-    return (
-        all_entities.groupby(["name", "type"], sort=False)
-        .agg({"description": list, "source_id": list})
-        .reset_index()
-    )
-
-
-def _merge_relationships(relationship_dfs) -> pd.DataFrame:
-    all_relationships = pd.concat(relationship_dfs, ignore_index=False)
-    return (
-        all_relationships.groupby(["source", "target"], sort=False)
-        .agg({"description": list, "source_id": list, "weight": "sum"})
-        .reset_index()
-    )
-
-
-def _prep_nodes(entities, summaries, graph) -> pd.DataFrame:
-    degrees_df = _compute_degree(graph)
+def _prep_nodes(entities, summaries) -> pd.DataFrame:
     entities.drop(columns=["description"], inplace=True)
-    nodes = (
-        entities.merge(summaries, on="name", how="left")
-        .merge(degrees_df, on="name")
-        .drop_duplicates(subset="name")
-        .rename(columns={"name": "title", "source_id": "text_unit_ids"})
+    nodes = entities.merge(summaries, on="title", how="left").drop_duplicates(
+        subset="title"
     )
     nodes = nodes.loc[nodes["title"].notna()].reset_index()
     nodes["human_readable_id"] = nodes.index
@@ -145,22 +88,12 @@ def _prep_edges(relationships, summaries) -> pd.DataFrame:
         relationships.drop(columns=["description"])
         .drop_duplicates(subset=["source", "target"])
         .merge(summaries, on=["source", "target"], how="left")
-        .rename(columns={"source_id": "text_unit_ids"})
     )
     edges["human_readable_id"] = edges.index
     edges["id"] = edges["human_readable_id"].apply(lambda _x: str(uuid4()))
     return edges
 
 
-def _compute_degree(graph: nx.Graph) -> pd.DataFrame:
-    return pd.DataFrame([
-        {"name": node, "degree": int(degree)}
-        for node, degree in graph.degree  # type: ignore
-    ])
-
-
-def _validate_data(df_list: list[pd.DataFrame]) -> bool:
-    """Validate that the dataframe list is valid. At least one dataframe must contain data."""
-    return any(
-        len(df) > 0 for df in df_list
-    )  # Check for len, not .empty, as the dfs have schemas in some cases
+def _validate_data(df: pd.DataFrame) -> bool:
+    """Validate that the dataframe has data."""
+    return len(df) > 0
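
For context, here is a minimal calling sketch of the refactored extract_graph above. The callbacks/cache objects and the text_units schema are assumptions standing in for whatever the host pipeline provides, and run_example is a hypothetical driver written for illustration, not part of this change:

import asyncio

import pandas as pd


async def run_example(callbacks, cache) -> None:
    # Hypothetical text units; real input comes from the upstream workflow.
    text_units = pd.DataFrame({
        "id": ["tu-1", "tu-2"],
        "text": ["Alice works with Bob.", "Bob founded Acme."],
    })
    # New call shape: no `storage` argument and no snapshot_* flags.
    nodes, edges = await extract_graph(
        text_units,
        callbacks,
        cache,
        entity_types=["person", "organization"],
    )
    # The function now returns two plain DataFrames; callers that still
    # want parquet or GraphML snapshots must persist these themselves.
    print(f"{len(nodes)} entities, {len(edges)} relationships")

Note the behavioral shift this diff implies: merging of per-text-unit results (the removed _merge_entities/_merge_relationships helpers) now happens upstream in extract_entities, and graph construction, degree computation, and snapshotting move out of this function entirely.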