Skip to content

Commit 907d271

Browse files
authored
Fix recursive report generation (#1669)
1 parent 53b06aa commit 907d271

File tree

6 files changed

+49
-51
lines changed

6 files changed

+49
-51
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"type": "patch",
3+
"description": "Fix report generation recursion."
4+
}

graphrag/index/flows/create_final_community_reports.py

Lines changed: 9 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,9 @@
1313
from graphrag.config.enums import AsyncType
1414
from graphrag.index.operations.summarize_communities import (
1515
prepare_community_reports,
16-
restore_community_hierarchy,
1716
summarize_communities,
1817
)
19-
from graphrag.index.operations.summarize_communities.community_reports_extractor import (
18+
from graphrag.index.operations.summarize_communities.community_reports_extractor.prep_community_report_context import (
2019
prep_community_report_context,
2120
)
2221
from graphrag.index.operations.summarize_communities.community_reports_extractor.schemas import (
@@ -39,9 +38,6 @@
3938
NODE_ID,
4039
NODE_NAME,
4140
)
42-
from graphrag.index.operations.summarize_communities.community_reports_extractor.utils import (
43-
get_levels,
44-
)
4541

4642

4743
async def create_final_community_reports(
@@ -66,35 +62,26 @@ async def create_final_community_reports(
6662
if claims_input is not None:
6763
claims = _prep_claims(claims_input)
6864

65+
max_input_length = summarization_strategy.get(
66+
"max_input_length", defaults.COMMUNITY_REPORT_MAX_INPUT_LENGTH
67+
)
68+
6969
local_contexts = prepare_community_reports(
7070
nodes,
7171
edges,
7272
claims,
7373
callbacks,
74-
summarization_strategy.get("max_input_length", 16_000),
74+
max_input_length,
7575
)
7676

77-
community_hierarchy = restore_community_hierarchy(nodes)
78-
levels = get_levels(nodes)
79-
80-
level_contexts = []
81-
for level in levels:
82-
level_context = prep_community_report_context(
83-
local_context_df=local_contexts,
84-
community_hierarchy_df=community_hierarchy,
85-
level=level,
86-
max_tokens=summarization_strategy.get(
87-
"max_input_tokens", defaults.COMMUNITY_REPORT_MAX_INPUT_LENGTH
88-
),
89-
)
90-
level_contexts.append(level_context)
91-
9277
community_reports = await summarize_communities(
78+
nodes,
9379
local_contexts,
94-
level_contexts,
80+
prep_community_report_context,
9581
callbacks,
9682
cache,
9783
summarization_strategy,
84+
max_input_length=max_input_length,
9885
async_mode=async_mode,
9986
num_threads=num_threads,
10087
)

graphrag/index/flows/create_final_community_reports_text.py

Lines changed: 8 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,8 @@
1313
from graphrag.config import defaults
1414
from graphrag.config.enums import AsyncType
1515
from graphrag.index.operations.summarize_communities import (
16-
restore_community_hierarchy,
1716
summarize_communities,
1817
)
19-
from graphrag.index.operations.summarize_communities.community_reports_extractor.utils import (
20-
get_levels,
21-
)
2218
from graphrag.index.operations.summarize_communities_text.context_builder import (
2319
prep_community_report_context,
2420
prep_local_context,
@@ -46,36 +42,25 @@ async def create_final_community_reports_text(
4642
nodes_df = nodes_input.merge(entities_df, on="id")
4743
nodes = nodes_df.loc[nodes_df.loc[:, "community"] != -1]
4844

49-
max_input_length = summarization_strategy.get("max_input_length", 16_000)
50-
5145
# TEMP: forcing override of the prompt until we can put it into config
5246
summarization_strategy["extraction_prompt"] = COMMUNITY_REPORT_PROMPT
53-
# build initial local context for all communities
47+
48+
max_input_length = summarization_strategy.get(
49+
"max_input_length", defaults.COMMUNITY_REPORT_MAX_INPUT_LENGTH
50+
)
51+
5452
local_contexts = prep_local_context(
5553
communities, text_units, nodes, max_input_length
5654
)
5755

58-
community_hierarchy = restore_community_hierarchy(nodes)
59-
levels = get_levels(nodes)
60-
61-
level_contexts = []
62-
for level in levels:
63-
level_context = prep_community_report_context(
64-
local_context_df=local_contexts,
65-
community_hierarchy_df=community_hierarchy,
66-
level=level,
67-
max_tokens=summarization_strategy.get(
68-
"max_input_tokens", defaults.COMMUNITY_REPORT_MAX_INPUT_LENGTH
69-
),
70-
)
71-
level_contexts.append(level_context)
72-
7356
community_reports = await summarize_communities(
57+
nodes,
7458
local_contexts,
75-
level_contexts,
59+
prep_community_report_context,
7660
callbacks,
7761
cache,
7862
summarization_strategy,
63+
max_input_length=max_input_length,
7964
async_mode=async_mode,
8065
num_threads=num_threads,
8166
)

graphrag/index/operations/summarize_communities/community_reports_extractor/prep_community_report_context.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030

3131

3232
def prep_community_report_context(
33+
report_df: pd.DataFrame | None,
3334
community_hierarchy_df: pd.DataFrame,
3435
local_context_df: pd.DataFrame,
3536
level: int,
@@ -42,8 +43,6 @@ def prep_community_report_context(
4243
- Check if local context fits within the limit, if yes use local context
4344
- If local context exceeds the limit, iteratively replace local context with sub-community reports, starting from the biggest sub-community
4445
"""
45-
report_df = pd.DataFrame()
46-
4746
# Filter by community level
4847
level_context_df = local_context_df.loc[
4948
local_context_df.loc[:, schemas.COMMUNITY_LEVEL] == level
@@ -62,7 +61,7 @@ def prep_community_report_context(
6261
if invalid_context_df.empty:
6362
return valid_context_df
6463

65-
if report_df.empty:
64+
if report_df is None or report_df.empty:
6665
invalid_context_df.loc[:, schemas.CONTEXT_STRING] = _sort_and_trim_context(
6766
invalid_context_df, max_tokens
6867
)

graphrag/index/operations/summarize_communities/summarize_communities.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
"""A module containing create_community_reports and load_strategy methods definition."""
55

66
import logging
7+
from collections.abc import Callable
78

89
import pandas as pd
910

@@ -12,6 +13,12 @@
1213
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
1314
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
1415
from graphrag.config.enums import AsyncType
16+
from graphrag.index.operations.summarize_communities.community_reports_extractor.utils import (
17+
get_levels,
18+
)
19+
from graphrag.index.operations.summarize_communities.restore_community_hierarchy import (
20+
restore_community_hierarchy,
21+
)
1522
from graphrag.index.operations.summarize_communities.typing import (
1623
CommunityReport,
1724
CommunityReportsStrategy,
@@ -24,11 +31,13 @@
2431

2532

2633
async def summarize_communities(
34+
nodes: pd.DataFrame,
2735
local_contexts,
28-
level_contexts,
36+
level_context_builder: Callable,
2937
callbacks: WorkflowCallbacks,
3038
cache: PipelineCache,
3139
strategy: dict,
40+
max_input_length: int,
3241
async_mode: AsyncType = AsyncType.AsyncIO,
3342
num_threads: int = 4,
3443
):
@@ -37,6 +46,20 @@ async def summarize_communities(
3746
tick = progress_ticker(callbacks.progress, len(local_contexts))
3847
runner = load_strategy(strategy["type"])
3948

49+
community_hierarchy = restore_community_hierarchy(nodes)
50+
levels = get_levels(nodes)
51+
52+
level_contexts = []
53+
for level in levels:
54+
level_context = level_context_builder(
55+
pd.DataFrame(reports),
56+
community_hierarchy_df=community_hierarchy,
57+
local_context_df=local_contexts,
58+
level=level,
59+
max_tokens=max_input_length,
60+
)
61+
level_contexts.append(level_context)
62+
4063
for level_context in level_contexts:
4164

4265
async def run_generate(record):

graphrag/index/operations/summarize_communities_text/context_builder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,10 @@ def prep_local_context(
7676

7777

7878
def prep_community_report_context(
79-
local_context_df: pd.DataFrame,
79+
report_df: pd.DataFrame | None,
8080
community_hierarchy_df: pd.DataFrame,
81+
local_context_df: pd.DataFrame,
8182
level: int,
82-
report_df: pd.DataFrame | None = None,
8383
max_tokens: int = 16000,
8484
) -> pd.DataFrame:
8585
"""

0 commit comments

Comments
 (0)