Skip to content

Commit 634e3ed

Browse files
authored
Transient entity graph (#1349)
* Make base_entity_graph transient
* Add transient snapshots
* Semver
* Fix unit test
* Fix smoke tests
1 parent 17658c5 commit 634e3ed

34 files changed

+209
-96
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"type": "patch",
3+
"description": "Transient entity graph and snapshotting."
4+
}

docs/config/env_vars.md

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -199,11 +199,13 @@ This section controls the reporting mechanism used by the pipeline, for common e
199199

200200
## Data Snapshotting
201201

202-
| Parameter | Description | Type | Required or Optional | Default |
203-
| ----------------------------------- | ------------------------------------------- | ------ | -------------------- | ------- |
204-
| `GRAPHRAG_SNAPSHOT_GRAPHML` | Whether to enable GraphML snapshots. | `bool` | optional | False |
205-
| `GRAPHRAG_SNAPSHOT_RAW_ENTITIES` | Whether to enable raw entity snapshots. | `bool` | optional | False |
206-
| `GRAPHRAG_SNAPSHOT_TOP_LEVEL_NODES` | Whether to enable top-level node snapshots. | `bool` | optional | False |
202+
| Parameter | Description | Type | Required or Optional | Default |
203+
| -------------------------------------- | ----------------------------------------------- | ------ | -------------------- | ------- |
204+
| `GRAPHRAG_SNAPSHOT_EMBEDDINGS` | Whether to enable embeddings snapshots. | `bool` | optional | False |
205+
| `GRAPHRAG_SNAPSHOT_GRAPHML` | Whether to enable GraphML snapshots. | `bool` | optional | False |
206+
| `GRAPHRAG_SNAPSHOT_RAW_ENTITIES` | Whether to enable raw entity snapshots. | `bool` | optional | False |
207+
| `GRAPHRAG_SNAPSHOT_TOP_LEVEL_NODES` | Whether to enable top-level node snapshots. | `bool` | optional | False |
208+
| `GRAPHRAG_SNAPSHOT_TRANSIENT` | Whether to enable transient table snapshots. | `bool` | optional | False |
207209

208210
# Miscellaneous Settings
209211

docs/config/json_yaml.md

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -216,9 +216,11 @@ This is the base LLM configuration section. Other steps may override this config
216216

217217
### Fields
218218

219-
- `graphml` **bool** - Emit graphml snapshots.
220-
- `raw_entities` **bool** - Emit raw entity snapshots.
221-
- `top_level_nodes` **bool** - Emit top-level-node snapshots.
219+
- `embeddings` **bool** - Emit embeddings snapshots to parquet.
220+
- `graphml` **bool** - Emit graph snapshots to GraphML.
221+
- `raw_entities` **bool** - Emit raw entity snapshots to JSON.
222+
- `top_level_nodes` **bool** - Emit top-level-node snapshots to JSON.
223+
- `transient` **bool** - Emit transient workflow table snapshots to parquet.
222224

223225
## encoding_model
224226

graphrag/config/create_graphrag_config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,7 @@ def hydrate_parallelization_params(
415415
top_level_nodes=reader.bool("top_level_nodes")
416416
or defs.SNAPSHOTS_TOP_LEVEL_NODES,
417417
embeddings=reader.bool("embeddings") or defs.SNAPSHOTS_EMBEDDINGS,
418+
transient=reader.bool("transient") or defs.SNAPSHOTS_TRANSIENT,
418419
)
419420
with reader.envvar_prefix(Section.umap), reader.use(values.get("umap")):
420421
umap_model = UmapConfig(

graphrag/config/defaults.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@
8383
SNAPSHOTS_RAW_ENTITIES = False
8484
SNAPSHOTS_TOP_LEVEL_NODES = False
8585
SNAPSHOTS_EMBEDDINGS = False
86+
SNAPSHOTS_TRANSIENT = False
8687
STORAGE_BASE_DIR = "output"
8788
STORAGE_TYPE = StorageType.file
8889
SUMMARIZE_DESCRIPTIONS_MAX_LENGTH = 500

graphrag/config/input_models/snapshots_config_input.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
class SnapshotsConfigInput(TypedDict):
1010
"""Configuration section for snapshots."""
1111

12+
embeddings: NotRequired[bool | str | None]
1213
graphml: NotRequired[bool | str | None]
1314
raw_entities: NotRequired[bool | str | None]
1415
top_level_nodes: NotRequired[bool | str | None]
16+
transient: NotRequired[bool | str | None]

graphrag/config/models/snapshots_config.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@
1111
class SnapshotsConfig(BaseModel):
1212
"""Configuration section for snapshots."""
1313

14+
embeddings: bool = Field(
15+
description="A flag indicating whether to take snapshots of embeddings.",
16+
default=defs.SNAPSHOTS_EMBEDDINGS,
17+
)
1418
graphml: bool = Field(
1519
description="A flag indicating whether to take snapshots of GraphML.",
1620
default=defs.SNAPSHOTS_GRAPHML,
@@ -23,7 +27,7 @@ class SnapshotsConfig(BaseModel):
2327
description="A flag indicating whether to take snapshots of top-level nodes.",
2428
default=defs.SNAPSHOTS_TOP_LEVEL_NODES,
2529
)
26-
embeddings: bool = Field(
27-
description="A flag indicating whether to take snapshots of embeddings.",
28-
default=defs.SNAPSHOTS_EMBEDDINGS,
30+
transient: bool = Field(
31+
description="A flag indicating whether to take snapshots of transient tables.",
32+
default=defs.SNAPSHOTS_TRANSIENT,
2933
)

graphrag/index/create_pipeline_config.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ def _text_unit_workflows(
171171
PipelineWorkflowReference(
172172
name=create_base_text_units,
173173
config={
174+
"snapshot_transient": settings.snapshots.transient,
174175
"chunk_by": settings.chunks.group_by_columns,
175176
"text_chunk": {
176177
"strategy": settings.chunks.resolved_strategy(
@@ -215,7 +216,9 @@ def _graph_workflows(settings: GraphRagConfig) -> list[PipelineWorkflowReference
215216
PipelineWorkflowReference(
216217
name=create_base_entity_graph,
217218
config={
218-
"graphml_snapshot": settings.snapshots.graphml,
219+
"snapshot_graphml": settings.snapshots.graphml,
220+
"snapshot_transient": settings.snapshots.transient,
221+
"snapshot_raw_entities": settings.snapshots.raw_entities,
219222
"entity_extract": {
220223
**settings.entity_extraction.parallelization.model_dump(),
221224
"async_mode": settings.entity_extraction.async_mode,

graphrag/index/flows/create_base_entity_graph.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,9 @@ async def create_base_entity_graph(
4242
summarization_strategy: dict[str, Any] | None = None,
4343
summarization_num_threads: int = 4,
4444
embedding_strategy: dict[str, Any] | None = None,
45-
graphml_snapshot_enabled: bool = False,
46-
raw_entity_snapshot_enabled: bool = False,
45+
snapshot_graphml_enabled: bool = False,
46+
snapshot_raw_entities_enabled: bool = False,
47+
snapshot_transient_enabled: bool = False,
4748
) -> pd.DataFrame:
4849
"""All the steps to create the base entity graph."""
4950
# this returns a graph for each text unit, to be merged later
@@ -92,15 +93,15 @@ async def create_base_entity_graph(
9293
strategy=embedding_strategy,
9394
)
9495

95-
if raw_entity_snapshot_enabled:
96+
if snapshot_raw_entities_enabled:
9697
await snapshot(
9798
entities,
9899
name="raw_extracted_entities",
99100
storage=storage,
100101
formats=["json"],
101102
)
102103

103-
if graphml_snapshot_enabled:
104+
if snapshot_graphml_enabled:
104105
await snapshot_graphml(
105106
merged_graph,
106107
name="merged_graph",
@@ -131,4 +132,14 @@ async def create_base_entity_graph(
131132
if embedding_strategy:
132133
final_columns.append("embeddings")
133134

134-
return cast(pd.DataFrame, clustered[final_columns])
135+
output = cast(pd.DataFrame, clustered[final_columns])
136+
137+
if snapshot_transient_enabled:
138+
await snapshot(
139+
output,
140+
name="create_base_entity_graph",
141+
storage=storage,
142+
formats=["parquet"],
143+
)
144+
145+
return output

graphrag/index/flows/create_base_text_units.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,20 @@
1515
)
1616

1717
from graphrag.index.operations.chunk_text import chunk_text
18+
from graphrag.index.operations.snapshot import snapshot
19+
from graphrag.index.storage import PipelineStorage
1820
from graphrag.index.utils import gen_md5_hash
1921

2022

21-
def create_base_text_units(
23+
async def create_base_text_units(
2224
documents: pd.DataFrame,
2325
callbacks: VerbCallbacks,
26+
storage: PipelineStorage,
2427
chunk_column_name: str,
2528
n_tokens_column_name: str,
2629
chunk_by_columns: list[str],
2730
chunk_strategy: dict[str, Any] | None = None,
31+
snapshot_transient_enabled: bool = False,
2832
) -> pd.DataFrame:
2933
"""All the steps to transform base text_units."""
3034
sort = documents.sort_values(by=["id"], ascending=[True])
@@ -73,10 +77,20 @@ def create_base_text_units(
7377
)
7478
chunked["id"] = chunked["chunk_id"]
7579

76-
return cast(
80+
output = cast(
7781
pd.DataFrame, chunked[chunked[chunk_column_name].notna()].reset_index(drop=True)
7882
)
7983

84+
if snapshot_transient_enabled:
85+
await snapshot(
86+
output,
87+
name="create_base_text_units",
88+
storage=storage,
89+
formats=["parquet"],
90+
)
91+
92+
return output
93+
8094

8195
# TODO: would be nice to inline this completely in the main method with pandas
8296
def _aggregate_df(

0 commit comments

Comments
 (0)