Commit c1c09ba

Flow cleanup (#1510)

* Move snapshots out of flows into verbs
* Move degree compute out of extract_graph
* Move entity/relationship df merging into extract
* Move "title" to extraction source
* Move text_unit_ids agg closer to extraction
* Move data definition
* Update test data
* Semver
* Update smoke tests
* Fix empty degree field and update smoke tests and verb data
* Move extractors (#1516)

  * Consolidate graph embedding and umap
  * Consolidate claim extraction
  * Consolidate graph extractor
  * Move graph utils
  * Move summarizers
  * Semver

  ---------

  Co-authored-by: Alonso Guevara <[email protected]>

* Fix syntax typo

---------

Co-authored-by: Alonso Guevara <[email protected]>
1 parent d0543d1 commit c1c09ba

33 files changed: +142 −172 lines
Lines changed: 4 additions & 0 deletions

@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Streamline flows."
+}

graphrag/index/flows/compute_communities.py
Lines changed: 1 addition & 13 deletions

@@ -9,15 +9,11 @@
 
 from graphrag.index.operations.cluster_graph import cluster_graph
 from graphrag.index.operations.create_graph import create_graph
-from graphrag.index.operations.snapshot import snapshot
-from graphrag.storage.pipeline_storage import PipelineStorage
 
 
-async def compute_communities(
+def compute_communities(
     base_relationship_edges: pd.DataFrame,
-    storage: PipelineStorage,
     clustering_strategy: dict[str, Any],
-    snapshot_transient_enabled: bool = False,
 ) -> pd.DataFrame:
     """All the steps to create the base entity graph."""
     graph = create_graph(base_relationship_edges)
@@ -32,12 +28,4 @@ async def compute_communities(
     ).explode("title")
     base_communities["community"] = base_communities["community"].astype(int)
 
-    if snapshot_transient_enabled:
-        await snapshot(
-            base_communities,
-            name="base_communities",
-            storage=storage,
-            formats=["parquet"],
-        )
-
     return base_communities
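With storage and snapshotting removed, compute_communities is now a plain synchronous function. A minimal call sketch, assuming a toy edge list and a hypothetical strategy dict (real strategies come from pipeline config):

    # Hypothetical usage sketch -- not from the commit; strategy keys are assumptions.
    import pandas as pd

    from graphrag.index.flows.compute_communities import compute_communities

    edges = pd.DataFrame({
        "source": ["a", "b", "c"],
        "target": ["b", "c", "a"],
        "weight": [1.0, 1.0, 1.0],
    })
    # No event loop or PipelineStorage needed anymore: the verb layer owns snapshots.
    base_communities = compute_communities(edges, clustering_strategy={"type": "leiden"})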

graphrag/index/flows/create_base_text_units.py
Lines changed: 2 additions & 18 deletions

@@ -15,18 +15,14 @@
 )
 
 from graphrag.index.operations.chunk_text import chunk_text
-from graphrag.index.operations.snapshot import snapshot
 from graphrag.index.utils.hashing import gen_sha512_hash
-from graphrag.storage.pipeline_storage import PipelineStorage
 
 
-async def create_base_text_units(
+def create_base_text_units(
     documents: pd.DataFrame,
     callbacks: VerbCallbacks,
-    storage: PipelineStorage,
     chunk_by_columns: list[str],
     chunk_strategy: dict[str, Any] | None = None,
-    snapshot_transient_enabled: bool = False,
 ) -> pd.DataFrame:
     """All the steps to transform base text_units."""
     sort = documents.sort_values(by=["id"], ascending=[True])
@@ -74,19 +70,7 @@ async def create_base_text_units(
     # rename for downstream consumption
     chunked.rename(columns={"chunk": "text"}, inplace=True)
 
-    output = cast(
-        "pd.DataFrame", chunked[chunked["text"].notna()].reset_index(drop=True)
-    )
-
-    if snapshot_transient_enabled:
-        await snapshot(
-            output,
-            name="create_base_text_units",
-            storage=storage,
-            formats=["parquet"],
-        )
-
-    return output
+    return cast("pd.DataFrame", chunked[chunked["text"].notna()].reset_index(drop=True))
 
 
 # TODO: would be nice to inline this completely in the main method with pandas
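Snapshotting now belongs to the caller. A hedged sketch of what a calling verb might do with the returned frame, reusing the snapshot signature visible in the removed lines:

    # Hypothetical call-site sketch: the verb, not the flow, decides on persistence.
    from graphrag.index.operations.snapshot import snapshot

    output = create_base_text_units(
        documents, callbacks, chunk_by_columns, chunk_strategy=chunk_strategy
    )
    if snapshot_transient_enabled:  # config flag, as in the removed flow code
        await snapshot(
            output, name="create_base_text_units", storage=storage, formats=["parquet"]
        )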

graphrag/index/flows/create_final_nodes.py
Lines changed: 12 additions & 7 deletions

@@ -10,6 +10,7 @@
     VerbCallbacks,
 )
 
+from graphrag.index.operations.compute_degree import compute_degree
 from graphrag.index.operations.create_graph import create_graph
 from graphrag.index.operations.embed_graph.embed_graph import embed_graph
 from graphrag.index.operations.layout_graph.layout_graph import layout_graph
@@ -37,15 +38,19 @@ def create_final_nodes(
         layout_strategy,
         embeddings=graph_embeddings,
     )
-    nodes = base_entity_nodes.merge(
-        layout, left_on="title", right_on="label", how="left"
-    )
 
-    joined = nodes.merge(base_communities, on="title", how="left")
-    joined["level"] = joined["level"].fillna(0).astype(int)
-    joined["community"] = joined["community"].fillna(-1).astype(int)
+    degrees = compute_degree(graph)
 
-    return joined.loc[
+    nodes = (
+        base_entity_nodes.merge(layout, left_on="title", right_on="label", how="left")
+        .merge(degrees, on="title", how="left")
+        .merge(base_communities, on="title", how="left")
+    )
+    nodes["level"] = nodes["level"].fillna(0).astype(int)
+    nodes["community"] = nodes["community"].fillna(-1).astype(int)
+    # disconnected nodes and those with no community even at level 0 can be missing degree
+    nodes["degree"] = nodes["degree"].fillna(0).astype(int)
+    return nodes.loc[
         :,
         [
             "id",

graphrag/index/flows/create_final_relationships.py
Lines changed: 7 additions & 2 deletions

@@ -5,20 +5,25 @@
 
 import pandas as pd
 
+from graphrag.index.operations.compute_degree import compute_degree
 from graphrag.index.operations.compute_edge_combined_degree import (
     compute_edge_combined_degree,
 )
+from graphrag.index.operations.create_graph import create_graph
 
 
 def create_final_relationships(
     base_relationship_edges: pd.DataFrame,
-    base_entity_nodes: pd.DataFrame,
 ) -> pd.DataFrame:
     """All the steps to transform final relationships."""
     relationships = base_relationship_edges
+
+    graph = create_graph(base_relationship_edges)
+    degrees = compute_degree(graph)
+
     relationships["combined_degree"] = compute_edge_combined_degree(
         relationships,
-        base_entity_nodes,
+        degrees,
         node_name_column="title",
         node_degree_column="degree",
         edge_source_column="source",
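The degree input to compute_edge_combined_degree can now be derived from the edges themselves. Assuming combined degree means the sum of the two endpoint degrees (consistent with the column names above), a toy equivalent of what the operation computes:

    import pandas as pd

    edges = pd.DataFrame({"source": ["a", "b"], "target": ["b", "c"]})
    degrees = pd.DataFrame({"title": ["a", "b", "c"], "degree": [1, 2, 1]})

    lookup = dict(zip(degrees["title"], degrees["degree"]))
    # combined_degree(edge) = degree(source) + degree(target)
    edges["combined_degree"] = edges["source"].map(lookup) + edges["target"].map(lookup)
    # a-b -> 1 + 2 = 3; b-c -> 2 + 1 = 3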

graphrag/index/flows/extract_graph.py
Lines changed: 13 additions & 80 deletions

@@ -6,41 +6,33 @@
 from typing import Any
 from uuid import uuid4
 
-import networkx as nx
 import pandas as pd
 from datashaper import (
     AsyncType,
     VerbCallbacks,
 )
 
 from graphrag.cache.pipeline_cache import PipelineCache
-from graphrag.index.operations.create_graph import create_graph
 from graphrag.index.operations.extract_entities import extract_entities
-from graphrag.index.operations.snapshot import snapshot
-from graphrag.index.operations.snapshot_graphml import snapshot_graphml
 from graphrag.index.operations.summarize_descriptions import (
     summarize_descriptions,
 )
-from graphrag.storage.pipeline_storage import PipelineStorage
 
 
 async def extract_graph(
     text_units: pd.DataFrame,
     callbacks: VerbCallbacks,
     cache: PipelineCache,
-    storage: PipelineStorage,
     extraction_strategy: dict[str, Any] | None = None,
     extraction_num_threads: int = 4,
     extraction_async_mode: AsyncType = AsyncType.AsyncIO,
     entity_types: list[str] | None = None,
     summarization_strategy: dict[str, Any] | None = None,
     summarization_num_threads: int = 4,
-    snapshot_graphml_enabled: bool = False,
-    snapshot_transient_enabled: bool = False,
 ) -> tuple[pd.DataFrame, pd.DataFrame]:
     """All the steps to create the base entity graph."""
     # this returns a graph for each text unit, to be merged later
-    entity_dfs, relationship_dfs = await extract_entities(
+    entities, relationships = await extract_entities(
         text_units,
         callbacks,
         cache,
@@ -52,87 +44,38 @@ async def extract_graph(
         num_threads=extraction_num_threads,
     )
 
-    if not _validate_data(entity_dfs):
+    if not _validate_data(entities):
         error_msg = "Entity Extraction failed. No entities detected during extraction."
         callbacks.error(error_msg)
         raise ValueError(error_msg)
 
-    if not _validate_data(relationship_dfs):
+    if not _validate_data(relationships):
         error_msg = (
             "Entity Extraction failed. No relationships detected during extraction."
         )
         callbacks.error(error_msg)
         raise ValueError(error_msg)
 
-    merged_entities = _merge_entities(entity_dfs)
-    merged_relationships = _merge_relationships(relationship_dfs)
-
     entity_summaries, relationship_summaries = await summarize_descriptions(
-        merged_entities,
-        merged_relationships,
+        entities,
+        relationships,
         callbacks,
         cache,
         strategy=summarization_strategy,
         num_threads=summarization_num_threads,
     )
 
-    base_relationship_edges = _prep_edges(merged_relationships, relationship_summaries)
-
-    graph = create_graph(base_relationship_edges)
-
-    base_entity_nodes = _prep_nodes(merged_entities, entity_summaries, graph)
-
-    if snapshot_graphml_enabled:
-        # todo: extract graphs at each level, and add in meta like descriptions
-        await snapshot_graphml(
-            graph,
-            name="graph",
-            storage=storage,
-        )
+    base_relationship_edges = _prep_edges(relationships, relationship_summaries)
 
-    if snapshot_transient_enabled:
-        await snapshot(
-            base_entity_nodes,
-            name="base_entity_nodes",
-            storage=storage,
-            formats=["parquet"],
-        )
-        await snapshot(
-            base_relationship_edges,
-            name="base_relationship_edges",
-            storage=storage,
-            formats=["parquet"],
-        )
+    base_entity_nodes = _prep_nodes(entities, entity_summaries)
 
     return (base_entity_nodes, base_relationship_edges)
 
 
-def _merge_entities(entity_dfs) -> pd.DataFrame:
-    all_entities = pd.concat(entity_dfs, ignore_index=True)
-    return (
-        all_entities.groupby(["name", "type"], sort=False)
-        .agg({"description": list, "source_id": list})
-        .reset_index()
-    )
-
-
-def _merge_relationships(relationship_dfs) -> pd.DataFrame:
-    all_relationships = pd.concat(relationship_dfs, ignore_index=False)
-    return (
-        all_relationships.groupby(["source", "target"], sort=False)
-        .agg({"description": list, "source_id": list, "weight": "sum"})
-        .reset_index()
-    )
-
-
-def _prep_nodes(entities, summaries, graph) -> pd.DataFrame:
-    degrees_df = _compute_degree(graph)
+def _prep_nodes(entities, summaries) -> pd.DataFrame:
     entities.drop(columns=["description"], inplace=True)
-    nodes = (
-        entities.merge(summaries, on="name", how="left")
-        .merge(degrees_df, on="name")
-        .drop_duplicates(subset="name")
-        .rename(columns={"name": "title", "source_id": "text_unit_ids"})
+    nodes = entities.merge(summaries, on="title", how="left").drop_duplicates(
+        subset="title"
     )
     nodes = nodes.loc[nodes["title"].notna()].reset_index()
     nodes["human_readable_id"] = nodes.index
@@ -145,22 +88,12 @@ def _prep_edges(relationships, summaries) -> pd.DataFrame:
         relationships.drop(columns=["description"])
         .drop_duplicates(subset=["source", "target"])
         .merge(summaries, on=["source", "target"], how="left")
-        .rename(columns={"source_id": "text_unit_ids"})
     )
     edges["human_readable_id"] = edges.index
     edges["id"] = edges["human_readable_id"].apply(lambda _x: str(uuid4()))
     return edges
 
 
-def _compute_degree(graph: nx.Graph) -> pd.DataFrame:
-    return pd.DataFrame([
-        {"name": node, "degree": int(degree)}
-        for node, degree in graph.degree  # type: ignore
-    ])
-
-
-def _validate_data(df_list: list[pd.DataFrame]) -> bool:
-    """Validate that the dataframe list is valid. At least one dataframe must contain data."""
-    return any(
-        len(df) > 0 for df in df_list
-    )  # Check for len, not .empty, as the dfs have schemas in some cases
+def _validate_data(df: pd.DataFrame) -> bool:
+    """Validate that the dataframe has data."""
+    return len(df) > 0
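After this cleanup the flow returns merged, summarized frames directly. A hedged call sketch against the new signature (argument values illustrative; strategies would come from pipeline config in practice):

    # Hypothetical call sketch: extract_graph now yields merged entity/relationship frames.
    base_entity_nodes, base_relationship_edges = await extract_graph(
        text_units,
        callbacks,
        cache,
        extraction_strategy=extraction_strategy,
        summarization_strategy=summarization_strategy,
    )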

graphrag/index/flows/generate_text_embeddings.py
Lines changed: 1 addition & 2 deletions

@@ -129,9 +129,8 @@ async def _run_and_snapshot_embeddings(
         strategy=text_embed_config["strategy"],
     )
 
-    data = data.loc[:, ["id", "embedding"]]
-
     if snapshot_embeddings_enabled is True:
+        data = data.loc[:, ["id", "embedding"]]
         await snapshot(
             data,
             name=f"embeddings.{name}",
graphrag/index/operations/compute_degree.py
Lines changed: 15 additions & 0 deletions

@@ -0,0 +1,15 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing compute_degree definition."""
+
+import networkx as nx
+import pandas as pd
+
+
+def compute_degree(graph: nx.Graph) -> pd.DataFrame:
+    """Create a new DataFrame with the degree of each node in the graph."""
+    return pd.DataFrame([
+        {"title": node, "degree": int(degree)}
+        for node, degree in graph.degree  # type: ignore
+    ])
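A quick sanity check of compute_degree on a toy graph (illustrative only):

    import networkx as nx

    from graphrag.index.operations.compute_degree import compute_degree

    g = nx.Graph()
    g.add_edges_from([("a", "b"), ("a", "c")])
    print(compute_degree(g))
    #   title  degree
    # 0     a       2
    # 1     b       1
    # 2     c       1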

graphrag/index/operations/extract_entities/extract_entities.py
Lines changed: 27 additions & 2 deletions

@@ -37,7 +37,7 @@ async def extract_entities(
     async_mode: AsyncType = AsyncType.AsyncIO,
     entity_types=DEFAULT_ENTITY_TYPES,
     num_threads: int = 4,
-) -> tuple[list[pd.DataFrame], list[pd.DataFrame]]:
+) -> tuple[pd.DataFrame, pd.DataFrame]:
     """
     Extract entities from a piece of text.
@@ -138,7 +138,10 @@ async def run_strategy(row):
         entity_dfs.append(pd.DataFrame(result[0]))
         relationship_dfs.append(pd.DataFrame(result[1]))
 
-    return (entity_dfs, relationship_dfs)
+    entities = _merge_entities(entity_dfs)
+    relationships = _merge_relationships(relationship_dfs)
+
+    return (entities, relationships)
 
 
 def _load_strategy(strategy_type: ExtractEntityStrategyType) -> EntityExtractStrategy:
@@ -162,3 +165,25 @@ def _load_strategy(strategy_type: ExtractEntityStrategyType) -> EntityExtractStr
         case _:
             msg = f"Unknown strategy: {strategy_type}"
             raise ValueError(msg)
+
+
+def _merge_entities(entity_dfs) -> pd.DataFrame:
+    all_entities = pd.concat(entity_dfs, ignore_index=True)
+    return (
+        all_entities.groupby(["title", "type"], sort=False)
+        .agg(description=("description", list), text_unit_ids=("source_id", list))
+        .reset_index()
+    )
+
+
+def _merge_relationships(relationship_dfs) -> pd.DataFrame:
+    all_relationships = pd.concat(relationship_dfs, ignore_index=False)
+    return (
+        all_relationships.groupby(["source", "target"], sort=False)
+        .agg(
+            description=("description", list),
+            text_unit_ids=("source_id", list),
+            weight=("weight", "sum"),
+        )
+        .reset_index()
+    )
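The new helpers lean on pandas named aggregation to collapse per-text-unit extractions into one row per entity. A toy run with invented data:

    import pandas as pd

    entity_dfs = [
        pd.DataFrame(
            {"title": ["ACME"], "type": ["ORG"], "description": ["a company"], "source_id": ["tu1"]}
        ),
        pd.DataFrame(
            {"title": ["ACME"], "type": ["ORG"], "description": ["a firm"], "source_id": ["tu2"]}
        ),
    ]
    merged = (
        pd.concat(entity_dfs, ignore_index=True)
        .groupby(["title", "type"], sort=False)
        .agg(description=("description", list), text_unit_ids=("source_id", list))
        .reset_index()
    )
    # one ACME/ORG row: description == ["a company", "a firm"], text_unit_ids == ["tu1", "tu2"]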

graphrag/index/operations/extract_entities/graph_intelligence_strategy.py
Lines changed: 1 addition & 1 deletion

@@ -106,7 +106,7 @@ async def run_extract_entities(
     )
 
     entities = [
-        ({"name": item[0], **(item[1] or {})})
+        ({"title": item[0], **(item[1] or {})})
         for item in graph.nodes(data=True)
         if item is not None
     ]
