Skip to content

Commit d17dfd0

Browse files
Graph collapse (#1464)
* Refactor graph creation * Semver * Spellcheck * Update integ pipeline * Fix cast * Improve pandas chaining * Cleaner apply * Use list comprehensions --------- Co-authored-by: Alonso Guevara <[email protected]>
1 parent 756f5c3 commit d17dfd0

File tree

61 files changed

+444
-1192
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

61 files changed

+444
-1192
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"type": "minor",
3+
"description": "Refactor graph creation."
4+
}

docs/config/yaml.md

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -230,8 +230,6 @@ This is the base LLM configuration section. Other steps may override this config
230230

231231
- `embeddings` **bool** - Export embeddings snapshots to parquet.
232232
- `graphml` **bool** - Export graph snapshots to GraphML.
233-
- `raw_entities` **bool** - Export raw entity snapshots to JSON.
234-
- `top_level_nodes` **bool** - Export top-level-node snapshots to JSON.
235233
- `transient` **bool** - Export transient workflow tables snapshots to parquet.
236234

237235
### encoding_model

graphrag/config/create_graphrag_config.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -409,9 +409,6 @@ def hydrate_parallelization_params(
409409
):
410410
snapshots_model = SnapshotsConfig(
411411
graphml=reader.bool("graphml") or defs.SNAPSHOTS_GRAPHML,
412-
raw_entities=reader.bool("raw_entities") or defs.SNAPSHOTS_RAW_ENTITIES,
413-
top_level_nodes=reader.bool("top_level_nodes")
414-
or defs.SNAPSHOTS_TOP_LEVEL_NODES,
415412
embeddings=reader.bool("embeddings") or defs.SNAPSHOTS_EMBEDDINGS,
416413
transient=reader.bool("transient") or defs.SNAPSHOTS_TRANSIENT,
417414
)

graphrag/config/defaults.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,6 @@
7979
REPORTING_TYPE = ReportingType.file
8080
REPORTING_BASE_DIR = "logs"
8181
SNAPSHOTS_GRAPHML = False
82-
SNAPSHOTS_RAW_ENTITIES = False
83-
SNAPSHOTS_TOP_LEVEL_NODES = False
8482
SNAPSHOTS_EMBEDDINGS = False
8583
SNAPSHOTS_TRANSIENT = False
8684
STORAGE_BASE_DIR = "output"

graphrag/config/init_content.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,6 @@
115115
116116
snapshots:
117117
graphml: false
118-
raw_entities: false
119-
top_level_nodes: false
120118
embeddings: false
121119
transient: false
122120

graphrag/config/input_models/snapshots_config_input.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,4 @@ class SnapshotsConfigInput(TypedDict):
1111

1212
embeddings: NotRequired[bool | str | None]
1313
graphml: NotRequired[bool | str | None]
14-
raw_entities: NotRequired[bool | str | None]
15-
top_level_nodes: NotRequired[bool | str | None]
1614
transient: NotRequired[bool | str | None]

graphrag/config/models/snapshots_config.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,6 @@ class SnapshotsConfig(BaseModel):
1919
description="A flag indicating whether to take snapshots of GraphML.",
2020
default=defs.SNAPSHOTS_GRAPHML,
2121
)
22-
raw_entities: bool = Field(
23-
description="A flag indicating whether to take snapshots of raw entities.",
24-
default=defs.SNAPSHOTS_RAW_ENTITIES,
25-
)
26-
top_level_nodes: bool = Field(
27-
description="A flag indicating whether to take snapshots of top-level nodes.",
28-
default=defs.SNAPSHOTS_TOP_LEVEL_NODES,
29-
)
3022
transient: bool = Field(
3123
description="A flag indicating whether to take snapshots of transient tables.",
3224
default=defs.SNAPSHOTS_TRANSIENT,

graphrag/index/create_pipeline_config.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,6 @@ def _graph_workflows(settings: GraphRagConfig) -> list[PipelineWorkflowReference
220220
config={
221221
"snapshot_graphml": settings.snapshots.graphml,
222222
"snapshot_transient": settings.snapshots.transient,
223-
"snapshot_raw_entities": settings.snapshots.raw_entities,
224223
"entity_extract": {
225224
**settings.entity_extraction.parallelization.model_dump(),
226225
"async_mode": settings.entity_extraction.async_mode,
@@ -236,11 +235,9 @@ def _graph_workflows(settings: GraphRagConfig) -> list[PipelineWorkflowReference
236235
settings.root_dir,
237236
),
238237
},
239-
"embed_graph_enabled": settings.embed_graph.enabled,
240238
"cluster_graph": {
241239
"strategy": settings.cluster_graph.resolved_strategy()
242240
},
243-
"embed_graph": {"strategy": settings.embed_graph.resolved_strategy()},
244241
},
245242
),
246243
PipelineWorkflowReference(
@@ -255,7 +252,8 @@ def _graph_workflows(settings: GraphRagConfig) -> list[PipelineWorkflowReference
255252
name=create_final_nodes,
256253
config={
257254
"layout_graph_enabled": settings.umap.enabled,
258-
"snapshot_top_level_nodes": settings.snapshots.top_level_nodes,
255+
"embed_graph_enabled": settings.embed_graph.enabled,
256+
"embed_graph": {"strategy": settings.embed_graph.resolved_strategy()},
259257
},
260258
),
261259
]

graphrag/index/flows/create_base_entity_graph.py

Lines changed: 97 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
"""All the steps to create the base entity graph."""
55

66
from typing import Any, cast
7+
from uuid import uuid4
78

9+
import networkx as nx
810
import pandas as pd
911
from datashaper import (
1012
AsyncType,
@@ -13,12 +15,10 @@
1315

1416
from graphrag.cache.pipeline_cache import PipelineCache
1517
from graphrag.index.operations.cluster_graph import cluster_graph
16-
from graphrag.index.operations.embed_graph import embed_graph
18+
from graphrag.index.operations.create_graph import create_graph
1719
from graphrag.index.operations.extract_entities import extract_entities
18-
from graphrag.index.operations.merge_graphs import merge_graphs
1920
from graphrag.index.operations.snapshot import snapshot
2021
from graphrag.index.operations.snapshot_graphml import snapshot_graphml
21-
from graphrag.index.operations.snapshot_rows import snapshot_rows
2222
from graphrag.index.operations.summarize_descriptions import (
2323
summarize_descriptions,
2424
)
@@ -30,23 +30,20 @@ async def create_base_entity_graph(
3030
callbacks: VerbCallbacks,
3131
cache: PipelineCache,
3232
storage: PipelineStorage,
33+
runtime_storage: PipelineStorage,
3334
clustering_strategy: dict[str, Any],
3435
extraction_strategy: dict[str, Any] | None = None,
3536
extraction_num_threads: int = 4,
3637
extraction_async_mode: AsyncType = AsyncType.AsyncIO,
3738
entity_types: list[str] | None = None,
38-
node_merge_config: dict[str, Any] | None = None,
39-
edge_merge_config: dict[str, Any] | None = None,
4039
summarization_strategy: dict[str, Any] | None = None,
4140
summarization_num_threads: int = 4,
42-
embedding_strategy: dict[str, Any] | None = None,
4341
snapshot_graphml_enabled: bool = False,
44-
snapshot_raw_entities_enabled: bool = False,
4542
snapshot_transient_enabled: bool = False,
46-
) -> pd.DataFrame:
43+
) -> None:
4744
"""All the steps to create the base entity graph."""
4845
# this returns a graph for each text unit, to be merged later
49-
entities, entity_graphs = await extract_entities(
46+
entity_dfs, relationship_dfs = await extract_entities(
5047
text_units,
5148
callbacks,
5249
cache,
@@ -55,89 +52,122 @@ async def create_base_entity_graph(
5552
strategy=extraction_strategy,
5653
async_mode=extraction_async_mode,
5754
entity_types=entity_types,
58-
to="entities",
5955
num_threads=extraction_num_threads,
6056
)
6157

62-
merged_graph = merge_graphs(
63-
entity_graphs,
64-
callbacks,
65-
node_operations=node_merge_config,
66-
edge_operations=edge_merge_config,
67-
)
58+
merged_entities = _merge_entities(entity_dfs)
59+
merged_relationships = _merge_relationships(relationship_dfs)
6860

69-
summarized = await summarize_descriptions(
70-
merged_graph,
61+
entity_summaries, relationship_summaries = await summarize_descriptions(
62+
merged_entities,
63+
merged_relationships,
7164
callbacks,
7265
cache,
7366
strategy=summarization_strategy,
7467
num_threads=summarization_num_threads,
7568
)
7669

77-
clustered = cluster_graph(
78-
summarized,
79-
callbacks,
80-
column="entity_graph",
70+
base_relationship_edges = _prep_edges(merged_relationships, relationship_summaries)
71+
72+
graph = create_graph(base_relationship_edges)
73+
74+
base_entity_nodes = _prep_nodes(merged_entities, entity_summaries, graph)
75+
76+
communities = cluster_graph(
77+
graph,
8178
strategy=clustering_strategy,
82-
to="clustered_graph",
83-
level_to="level",
8479
)
8580

86-
if embedding_strategy:
87-
clustered["embeddings"] = await embed_graph(
88-
clustered,
89-
callbacks,
90-
column="clustered_graph",
91-
strategy=embedding_strategy,
92-
)
81+
base_communities = _prep_communities(communities)
9382

94-
if snapshot_raw_entities_enabled:
95-
await snapshot(
96-
entities,
97-
name="raw_extracted_entities",
98-
storage=storage,
99-
formats=["json"],
100-
)
83+
await runtime_storage.set("base_entity_nodes", base_entity_nodes)
84+
await runtime_storage.set("base_relationship_edges", base_relationship_edges)
85+
await runtime_storage.set("base_communities", base_communities)
10186

10287
if snapshot_graphml_enabled:
88+
# todo: extract graphs at each level, and add in meta like descriptions
10389
await snapshot_graphml(
104-
merged_graph,
105-
name="merged_graph",
90+
graph,
91+
name="graph",
10692
storage=storage,
10793
)
108-
await snapshot_graphml(
109-
summarized,
110-
name="summarized_graph",
94+
95+
if snapshot_transient_enabled:
96+
await snapshot(
97+
base_entity_nodes,
98+
name="base_entity_nodes",
11199
storage=storage,
100+
formats=["parquet"],
112101
)
113-
await snapshot_rows(
114-
clustered,
115-
column="clustered_graph",
116-
base_name="clustered_graph",
102+
await snapshot(
103+
base_relationship_edges,
104+
name="base_relationship_edges",
117105
storage=storage,
118-
formats=[{"format": "text", "extension": "graphml"}],
106+
formats=["parquet"],
119107
)
120-
if embedding_strategy:
121-
await snapshot_rows(
122-
clustered,
123-
column="entity_graph",
124-
base_name="embedded_graph",
125-
storage=storage,
126-
formats=[{"format": "text", "extension": "graphml"}],
127-
)
128-
129-
final_columns = ["level", "clustered_graph"]
130-
if embedding_strategy:
131-
final_columns.append("embeddings")
132-
133-
output = cast(pd.DataFrame, clustered[final_columns])
134-
135-
if snapshot_transient_enabled:
136108
await snapshot(
137-
output,
138-
name="create_base_entity_graph",
109+
base_communities,
110+
name="base_communities",
139111
storage=storage,
140112
formats=["parquet"],
141113
)
142114

143-
return output
115+
116+
def _merge_entities(entity_dfs) -> pd.DataFrame:
117+
all_entities = pd.concat(entity_dfs, ignore_index=True)
118+
return (
119+
all_entities.groupby(["name", "type"], sort=False)
120+
.agg({"description": list, "source_id": list})
121+
.reset_index()
122+
)
123+
124+
125+
def _merge_relationships(relationship_dfs) -> pd.DataFrame:
126+
all_relationships = pd.concat(relationship_dfs, ignore_index=False)
127+
return (
128+
all_relationships.groupby(["source", "target"], sort=False)
129+
.agg({"description": list, "source_id": list, "weight": "sum"})
130+
.reset_index()
131+
)
132+
133+
134+
def _prep_nodes(entities, summaries, graph) -> pd.DataFrame:
135+
degrees_df = _compute_degree(graph)
136+
entities.drop(columns=["description"], inplace=True)
137+
nodes = (
138+
entities.merge(summaries, on="name", how="left")
139+
.merge(degrees_df, on="name")
140+
.drop_duplicates(subset="name")
141+
.rename(columns={"name": "title", "source_id": "text_unit_ids"})
142+
)
143+
nodes = nodes.loc[nodes["title"].notna()].reset_index()
144+
nodes["human_readable_id"] = nodes.index
145+
nodes["id"] = nodes["human_readable_id"].apply(lambda _x: str(uuid4()))
146+
return nodes
147+
148+
149+
def _prep_edges(relationships, summaries) -> pd.DataFrame:
150+
edges = (
151+
relationships.drop(columns=["description"])
152+
.drop_duplicates(subset=["source", "target"])
153+
.merge(summaries, on=["source", "target"], how="left")
154+
.rename(columns={"source_id": "text_unit_ids"})
155+
)
156+
edges["human_readable_id"] = edges.index
157+
edges["id"] = edges["human_readable_id"].apply(lambda _x: str(uuid4()))
158+
return edges
159+
160+
161+
def _prep_communities(communities) -> pd.DataFrame:
162+
base_communities = pd.DataFrame(
163+
communities, columns=cast(Any, ["level", "community", "title"])
164+
)
165+
base_communities = base_communities.explode("title")
166+
base_communities["community"] = base_communities["community"].astype(int)
167+
return base_communities
168+
169+
170+
def _compute_degree(graph: nx.Graph) -> pd.DataFrame:
171+
return pd.DataFrame([
172+
{"name": node, "degree": int(degree)} for node, degree in graph.degree
173+
]) # type: ignore

0 commit comments

Comments
 (0)