44"""All the steps to create the base entity graph."""
55
66from typing import Any , cast
7+ from uuid import uuid4
78
9+ import networkx as nx
810import pandas as pd
911from datashaper import (
1012 AsyncType ,
1315
1416from graphrag .cache .pipeline_cache import PipelineCache
1517from graphrag .index .operations .cluster_graph import cluster_graph
16- from graphrag .index .operations .embed_graph import embed_graph
18+ from graphrag .index .operations .create_graph import create_graph
1719from graphrag .index .operations .extract_entities import extract_entities
18- from graphrag .index .operations .merge_graphs import merge_graphs
1920from graphrag .index .operations .snapshot import snapshot
2021from graphrag .index .operations .snapshot_graphml import snapshot_graphml
21- from graphrag .index .operations .snapshot_rows import snapshot_rows
2222from graphrag .index .operations .summarize_descriptions import (
2323 summarize_descriptions ,
2424)
@@ -30,23 +30,20 @@ async def create_base_entity_graph(
3030 callbacks : VerbCallbacks ,
3131 cache : PipelineCache ,
3232 storage : PipelineStorage ,
33+ runtime_storage : PipelineStorage ,
3334 clustering_strategy : dict [str , Any ],
3435 extraction_strategy : dict [str , Any ] | None = None ,
3536 extraction_num_threads : int = 4 ,
3637 extraction_async_mode : AsyncType = AsyncType .AsyncIO ,
3738 entity_types : list [str ] | None = None ,
38- node_merge_config : dict [str , Any ] | None = None ,
39- edge_merge_config : dict [str , Any ] | None = None ,
4039 summarization_strategy : dict [str , Any ] | None = None ,
4140 summarization_num_threads : int = 4 ,
42- embedding_strategy : dict [str , Any ] | None = None ,
4341 snapshot_graphml_enabled : bool = False ,
44- snapshot_raw_entities_enabled : bool = False ,
4542 snapshot_transient_enabled : bool = False ,
46- ) -> pd . DataFrame :
43+ ) -> None :
4744 """All the steps to create the base entity graph."""
4845 # this returns a graph for each text unit, to be merged later
49- entities , entity_graphs = await extract_entities (
46+ entity_dfs , relationship_dfs = await extract_entities (
5047 text_units ,
5148 callbacks ,
5249 cache ,
@@ -55,89 +52,122 @@ async def create_base_entity_graph(
5552 strategy = extraction_strategy ,
5653 async_mode = extraction_async_mode ,
5754 entity_types = entity_types ,
58- to = "entities" ,
5955 num_threads = extraction_num_threads ,
6056 )
6157
62- merged_graph = merge_graphs (
63- entity_graphs ,
64- callbacks ,
65- node_operations = node_merge_config ,
66- edge_operations = edge_merge_config ,
67- )
58+ merged_entities = _merge_entities (entity_dfs )
59+ merged_relationships = _merge_relationships (relationship_dfs )
6860
69- summarized = await summarize_descriptions (
70- merged_graph ,
61+ entity_summaries , relationship_summaries = await summarize_descriptions (
62+ merged_entities ,
63+ merged_relationships ,
7164 callbacks ,
7265 cache ,
7366 strategy = summarization_strategy ,
7467 num_threads = summarization_num_threads ,
7568 )
7669
77- clustered = cluster_graph (
78- summarized ,
79- callbacks ,
80- column = "entity_graph" ,
70+ base_relationship_edges = _prep_edges (merged_relationships , relationship_summaries )
71+
72+ graph = create_graph (base_relationship_edges )
73+
74+ base_entity_nodes = _prep_nodes (merged_entities , entity_summaries , graph )
75+
76+ communities = cluster_graph (
77+ graph ,
8178 strategy = clustering_strategy ,
82- to = "clustered_graph" ,
83- level_to = "level" ,
8479 )
8580
86- if embedding_strategy :
87- clustered ["embeddings" ] = await embed_graph (
88- clustered ,
89- callbacks ,
90- column = "clustered_graph" ,
91- strategy = embedding_strategy ,
92- )
81+ base_communities = _prep_communities (communities )
9382
94- if snapshot_raw_entities_enabled :
95- await snapshot (
96- entities ,
97- name = "raw_extracted_entities" ,
98- storage = storage ,
99- formats = ["json" ],
100- )
83+ await runtime_storage .set ("base_entity_nodes" , base_entity_nodes )
84+ await runtime_storage .set ("base_relationship_edges" , base_relationship_edges )
85+ await runtime_storage .set ("base_communities" , base_communities )
10186
10287 if snapshot_graphml_enabled :
88+ # todo: extract graphs at each level, and add in meta like descriptions
10389 await snapshot_graphml (
104- merged_graph ,
105- name = "merged_graph " ,
90+ graph ,
91+ name = "graph " ,
10692 storage = storage ,
10793 )
108- await snapshot_graphml (
109- summarized ,
110- name = "summarized_graph" ,
94+
95+ if snapshot_transient_enabled :
96+ await snapshot (
97+ base_entity_nodes ,
98+ name = "base_entity_nodes" ,
11199 storage = storage ,
100+ formats = ["parquet" ],
112101 )
113- await snapshot_rows (
114- clustered ,
115- column = "clustered_graph" ,
116- base_name = "clustered_graph" ,
102+ await snapshot (
103+ base_relationship_edges ,
104+ name = "base_relationship_edges" ,
117105 storage = storage ,
118- formats = [{ "format" : "text" , "extension" : "graphml" } ],
106+ formats = ["parquet" ],
119107 )
120- if embedding_strategy :
121- await snapshot_rows (
122- clustered ,
123- column = "entity_graph" ,
124- base_name = "embedded_graph" ,
125- storage = storage ,
126- formats = [{"format" : "text" , "extension" : "graphml" }],
127- )
128-
129- final_columns = ["level" , "clustered_graph" ]
130- if embedding_strategy :
131- final_columns .append ("embeddings" )
132-
133- output = cast (pd .DataFrame , clustered [final_columns ])
134-
135- if snapshot_transient_enabled :
136108 await snapshot (
137- output ,
138- name = "create_base_entity_graph " ,
109+ base_communities ,
110+ name = "base_communities " ,
139111 storage = storage ,
140112 formats = ["parquet" ],
141113 )
142114
143- return output
115+
116+ def _merge_entities (entity_dfs ) -> pd .DataFrame :
117+ all_entities = pd .concat (entity_dfs , ignore_index = True )
118+ return (
119+ all_entities .groupby (["name" , "type" ], sort = False )
120+ .agg ({"description" : list , "source_id" : list })
121+ .reset_index ()
122+ )
123+
124+
125+ def _merge_relationships (relationship_dfs ) -> pd .DataFrame :
126+ all_relationships = pd .concat (relationship_dfs , ignore_index = False )
127+ return (
128+ all_relationships .groupby (["source" , "target" ], sort = False )
129+ .agg ({"description" : list , "source_id" : list , "weight" : "sum" })
130+ .reset_index ()
131+ )
132+
133+
def _prep_nodes(entities, summaries, graph) -> pd.DataFrame:
    """Build the final node table from merged entities, summaries, and the graph.

    Joins the summarized description (from `summaries`, on "name") and the
    node degree (from `graph`) onto `entities`, deduplicates by name, renames
    name -> title and source_id -> text_unit_ids, and assigns sequential
    human_readable_id plus a random uuid id per node.
    """
    degrees_df = _compute_degree(graph)
    # Drop the raw list-valued descriptions WITHOUT mutating the caller's
    # frame (the original used inplace=True, a side effect on an input
    # argument); the summarized description is re-attached from `summaries`.
    nodes = (
        entities.drop(columns=["description"])
        .merge(summaries, on="name", how="left")
        .merge(degrees_df, on="name")
        .drop_duplicates(subset="name")
        .rename(columns={"name": "title", "source_id": "text_unit_ids"})
    )
    # Drop nodes with no title; reset_index keeps the old index as a column,
    # matching the established output schema.
    nodes = nodes.loc[nodes["title"].notna()].reset_index()
    nodes["human_readable_id"] = nodes.index
    nodes["id"] = nodes["human_readable_id"].apply(lambda _x: str(uuid4()))
    return nodes
147+
148+
149+ def _prep_edges (relationships , summaries ) -> pd .DataFrame :
150+ edges = (
151+ relationships .drop (columns = ["description" ])
152+ .drop_duplicates (subset = ["source" , "target" ])
153+ .merge (summaries , on = ["source" , "target" ], how = "left" )
154+ .rename (columns = {"source_id" : "text_unit_ids" })
155+ )
156+ edges ["human_readable_id" ] = edges .index
157+ edges ["id" ] = edges ["human_readable_id" ].apply (lambda _x : str (uuid4 ()))
158+ return edges
159+
160+
161+ def _prep_communities (communities ) -> pd .DataFrame :
162+ base_communities = pd .DataFrame (
163+ communities , columns = cast (Any , ["level" , "community" , "title" ])
164+ )
165+ base_communities = base_communities .explode ("title" )
166+ base_communities ["community" ] = base_communities ["community" ].astype (int )
167+ return base_communities
168+
169+
170+ def _compute_degree (graph : nx .Graph ) -> pd .DataFrame :
171+ return pd .DataFrame ([
172+ {"name" : node , "degree" : int (degree )} for node , degree in graph .degree
173+ ]) # type: ignore
0 commit comments