Skip to content

Commit 96219a2

Browse files
authored
Register workflows (#1691)
* Add workflow registration * Add ability to mutate config by workflows * Separate graph finalization * Separate graph pruning * Semver * Update tests * Update smoke tests * Fix iterrows on create_graph * Remove prune_graph from llm construction * Update test data * Remove prune_graph from smoke tests
1 parent 981fd31 commit 96219a2

37 files changed

+369
-166
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"type": "patch",
3+
"description": "Separates graph pruning for differential usage."
4+
}

graphrag/api/index.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
from graphrag.config.enums import CacheType, IndexingMethod
1717
from graphrag.config.models.graph_rag_config import GraphRagConfig
1818
from graphrag.index.run.run_pipeline import run_pipeline
19-
from graphrag.index.typing import PipelineRunResult
20-
from graphrag.index.workflows.factory import create_pipeline
19+
from graphrag.index.typing import PipelineRunResult, WorkflowFunction
20+
from graphrag.index.workflows.factory import PipelineFactory
2121
from graphrag.logger.base import ProgressLogger
2222

2323
log = logging.getLogger(__name__)
@@ -63,7 +63,7 @@ async def build_index(
6363
if memory_profile:
6464
log.warning("New pipeline does not yet support memory profiling.")
6565

66-
pipeline = create_pipeline(config, method)
66+
pipeline = PipelineFactory.create_pipeline(config, method)
6767

6868
async for output in run_pipeline(
6969
pipeline,
@@ -82,3 +82,8 @@ async def build_index(
8282
progress_logger.info(str(output.result))
8383

8484
return outputs
85+
86+
87+
def register_workflow_function(name: str, workflow: WorkflowFunction):
    """Register a custom workflow under *name*.

    After registration the name may be listed in the settings.yaml
    workflows list, and the pipeline will run this function for it.
    """
    # Thin convenience wrapper over the factory's registry.
    PipelineFactory.register(name, workflow)

graphrag/index/flows/extract_graph.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,9 @@
1010
from graphrag.cache.pipeline_cache import PipelineCache
1111
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
1212
from graphrag.config.enums import AsyncType
13-
from graphrag.config.models.embed_graph_config import EmbedGraphConfig
1413
from graphrag.index.operations.extract_graph.extract_graph import (
1514
extract_graph as extractor,
1615
)
17-
from graphrag.index.operations.finalize_entities import finalize_entities
18-
from graphrag.index.operations.finalize_relationships import finalize_relationships
1916
from graphrag.index.operations.summarize_descriptions import (
2017
summarize_descriptions,
2118
)
@@ -31,8 +28,6 @@ async def extract_graph(
3128
entity_types: list[str] | None = None,
3229
summarization_strategy: dict[str, Any] | None = None,
3330
summarization_num_threads: int = 4,
34-
embed_config: EmbedGraphConfig | None = None,
35-
layout_enabled: bool = False,
3631
) -> tuple[pd.DataFrame, pd.DataFrame]:
3732
"""All the steps to create the base entity graph."""
3833
# this returns a graph for each text unit, to be merged later
@@ -76,11 +71,7 @@ async def extract_graph(
7671
extracted_entities.drop(columns=["description"], inplace=True)
7772
entities = extracted_entities.merge(entity_summaries, on="title", how="left")
7873

79-
final_entities = finalize_entities(
80-
entities, relationships, callbacks, embed_config, layout_enabled
81-
)
82-
final_relationships = finalize_relationships(relationships)
83-
return (final_entities, final_relationships)
74+
return (entities, relationships)
8475

8576

8677
def _validate_data(df: pd.DataFrame) -> bool:

graphrag/index/flows/extract_graph_nlp.py

Lines changed: 4 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -6,29 +6,17 @@
66
import pandas as pd
77

88
from graphrag.cache.pipeline_cache import PipelineCache
9-
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
10-
from graphrag.config.models.embed_graph_config import EmbedGraphConfig
119
from graphrag.config.models.extract_graph_nlp_config import ExtractGraphNLPConfig
12-
from graphrag.config.models.prune_graph_config import PruneGraphConfig
1310
from graphrag.index.operations.build_noun_graph.build_noun_graph import build_noun_graph
1411
from graphrag.index.operations.build_noun_graph.np_extractors.factory import (
1512
create_noun_phrase_extractor,
1613
)
17-
from graphrag.index.operations.create_graph import create_graph
18-
from graphrag.index.operations.finalize_entities import finalize_entities
19-
from graphrag.index.operations.finalize_relationships import finalize_relationships
20-
from graphrag.index.operations.graph_to_dataframes import graph_to_dataframes
21-
from graphrag.index.operations.prune_graph import prune_graph
2214

2315

2416
async def extract_graph_nlp(
2517
text_units: pd.DataFrame,
26-
callbacks: WorkflowCallbacks,
2718
cache: PipelineCache,
2819
extraction_config: ExtractGraphNLPConfig,
29-
pruning_config: PruneGraphConfig,
30-
embed_config: EmbedGraphConfig | None = None,
31-
layout_enabled: bool = False,
3220
) -> tuple[pd.DataFrame, pd.DataFrame]:
3321
"""All the steps to create the base entity graph."""
3422
text_analyzer_config = extraction_config.text_analyzer
@@ -41,37 +29,9 @@ async def extract_graph_nlp(
4129
cache=cache,
4230
)
4331

44-
# create a temporary graph to prune, then turn it back into dataframes
45-
graph = create_graph(extracted_edges, edge_attr=["weight"], nodes=extracted_nodes)
46-
pruned = prune_graph(
47-
graph,
48-
min_node_freq=pruning_config.min_node_freq,
49-
max_node_freq_std=pruning_config.max_node_freq_std,
50-
min_node_degree=pruning_config.min_node_degree,
51-
max_node_degree_std=pruning_config.max_node_degree_std,
52-
min_edge_weight_pct=pruning_config.min_edge_weight_pct,
53-
remove_ego_nodes=pruning_config.remove_ego_nodes,
54-
lcc_only=pruning_config.lcc_only,
55-
)
56-
57-
pruned_nodes, pruned_edges = graph_to_dataframes(
58-
pruned, node_columns=["title"], edge_columns=["source", "target"]
59-
)
60-
61-
# subset the full nodes and edges to only include the pruned remainders
62-
joined_nodes = pruned_nodes.merge(extracted_nodes, on="title", how="inner")
63-
joined_edges = pruned_edges.merge(
64-
extracted_edges, on=["source", "target"], how="inner"
65-
)
66-
6732
# add in any other columns required by downstream workflows
68-
joined_nodes["type"] = "NOUN PHRASE"
69-
joined_nodes["description"] = ""
33+
extracted_nodes["type"] = "NOUN PHRASE"
34+
extracted_nodes["description"] = ""
35+
extracted_edges["description"] = ""
7036

71-
joined_edges["description"] = ""
72-
73-
final_entities = finalize_entities(
74-
joined_nodes, joined_edges, callbacks, embed_config, layout_enabled
75-
)
76-
final_relationships = finalize_relationships(joined_edges)
77-
return (final_entities, final_relationships)
37+
return (extracted_nodes, extracted_edges)
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# Copyright (c) 2024 Microsoft Corporation.
2+
# Licensed under the MIT License
3+
4+
"""All the steps to create the base entity graph."""
5+
6+
import pandas as pd
7+
8+
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
9+
from graphrag.config.models.embed_graph_config import EmbedGraphConfig
10+
from graphrag.index.operations.finalize_entities import finalize_entities
11+
from graphrag.index.operations.finalize_relationships import finalize_relationships
12+
13+
14+
def finalize_graph(
    entities: pd.DataFrame,
    relationships: pd.DataFrame,
    callbacks: WorkflowCallbacks,
    embed_config: EmbedGraphConfig | None = None,
    layout_enabled: bool = False,
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """All the steps to finalize the entity and relationship formats.

    Returns the finalized (entities, relationships) pair.
    """
    # NOTE(review): entity finalization receives both frames — presumably it
    # derives per-entity graph columns from the relationships; confirm in
    # finalize_entities.
    return (
        finalize_entities(
            entities, relationships, callbacks, embed_config, layout_enabled
        ),
        finalize_relationships(relationships),
    )
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Copyright (c) 2024 Microsoft Corporation.
2+
# Licensed under the MIT License
3+
4+
"""Prune a full graph based on graph statistics."""
5+
6+
import pandas as pd
7+
8+
from graphrag.config.models.prune_graph_config import PruneGraphConfig
9+
from graphrag.index.operations.create_graph import create_graph
10+
from graphrag.index.operations.graph_to_dataframes import graph_to_dataframes
11+
from graphrag.index.operations.prune_graph import prune_graph as prune_graph_operation
12+
13+
14+
def prune_graph(
    entities: pd.DataFrame,
    relationships: pd.DataFrame,
    pruning_config: PruneGraphConfig,
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Prune a full graph based on graph statistics.

    Returns the (entities, relationships) subsets that survive pruning,
    each retaining all of its original columns.
    """
    # Build a transient graph object so the pruning operation can evaluate
    # node/edge statistics, then project the survivors back to dataframes.
    full_graph = create_graph(relationships, edge_attr=["weight"], nodes=entities)
    kept = prune_graph_operation(
        full_graph,
        **{
            option: getattr(pruning_config, option)
            for option in (
                "min_node_freq",
                "max_node_freq_std",
                "min_node_degree",
                "max_node_degree_std",
                "min_edge_weight_pct",
                "remove_ego_nodes",
                "lcc_only",
            )
        },
    )

    kept_nodes, kept_edges = graph_to_dataframes(
        kept, node_columns=["title"], edge_columns=["source", "target"]
    )

    # Inner joins restrict the original frames to the pruned remainders
    # while preserving every other column they carried.
    remaining_entities = kept_nodes.merge(entities, on="title", how="inner")
    remaining_relationships = kept_edges.merge(
        relationships, on=["source", "target"], how="inner"
    )

    return (remaining_entities, remaining_relationships)

graphrag/index/operations/build_noun_graph/build_noun_graph.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ async def _extract_nodes(
4444
Extract initial nodes and edges from text units.
4545
4646
Input: text unit df with schema [id, text, document_id]
47-
Returns a dataframe with schema [id, title, freq, text_unit_ids].
47+
Returns a dataframe with schema [id, title, frequency, text_unit_ids].
4848
"""
4949
cache = cache or NoopPipelineCache()
5050
cache = cache.child("extract_noun_phrases")
@@ -76,9 +76,9 @@ async def extract(row):
7676
noun_node_df.groupby("title").agg({"text_unit_id": list}).reset_index()
7777
)
7878
grouped_node_df = grouped_node_df.rename(columns={"text_unit_id": "text_unit_ids"})
79-
grouped_node_df["freq"] = grouped_node_df["text_unit_ids"].apply(len)
80-
grouped_node_df = grouped_node_df[["title", "freq", "text_unit_ids"]]
81-
return grouped_node_df.loc[:, ["title", "freq", "text_unit_ids"]]
79+
grouped_node_df["frequency"] = grouped_node_df["text_unit_ids"].apply(len)
80+
grouped_node_df = grouped_node_df[["title", "frequency", "text_unit_ids"]]
81+
return grouped_node_df.loc[:, ["title", "frequency", "text_unit_ids"]]
8282

8383

8484
def _extract_edges(
@@ -89,7 +89,7 @@ def _extract_edges(
8989
Extract edges from nodes.
9090
9191
Nodes appear in the same text unit are connected.
92-
Input: nodes_df with schema [id, title, freq, text_unit_ids]
92+
Input: nodes_df with schema [id, title, frequency, text_unit_ids]
9393
Returns: edges_df with schema [source, target, weight, text_unit_ids]
9494
"""
9595
text_units_df = nodes_df.explode("text_unit_ids")
@@ -156,7 +156,7 @@ def _calculate_pmi_edge_weights(
156156
nodes_df: pd.DataFrame,
157157
edges_df: pd.DataFrame,
158158
node_name_col="title",
159-
node_freq_col="freq",
159+
node_freq_col="frequency",
160160
edge_weight_col="weight",
161161
edge_source_col="source",
162162
edge_target_col="target",

graphrag/index/operations/create_graph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,6 @@ def create_graph(
1818

1919
if nodes is not None:
2020
nodes.set_index(node_id, inplace=True)
21-
graph.add_nodes_from(nodes.to_dict("index").items())
21+
graph.add_nodes_from((n, dict(d)) for n, d in nodes.iterrows())
2222

2323
return graph

graphrag/index/operations/extract_graph/extract_graph.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,11 @@ def _merge_entities(entity_dfs) -> pd.DataFrame:
154154
all_entities = pd.concat(entity_dfs, ignore_index=True)
155155
return (
156156
all_entities.groupby(["title", "type"], sort=False)
157-
.agg(description=("description", list), text_unit_ids=("source_id", list))
157+
.agg(
158+
description=("description", list),
159+
text_unit_ids=("source_id", list),
160+
frequency=("source_id", "count"),
161+
)
158162
.reset_index()
159163
)
160164

graphrag/index/operations/finalize_entities.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def finalize_entities(
5959
"type",
6060
"description",
6161
"text_unit_ids",
62+
"frequency",
6263
"degree",
6364
"x",
6465
"y",

0 commit comments

Comments (0)