Skip to content

Commit 981fd31

Browse files
Community children (#1704)
* Add children to the community tables * Replace NaN children with empty list * Replace subcommunity logic with built-in parent/child fields * Remove restore_community_hierarchy * Add children and frequency to migration notebook * Format * Semver * Add children to reports * Update tests --------- Co-authored-by: Alonso Guevara <[email protected]>
1 parent 35b6393 commit 981fd31

23 files changed

+118
-127
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"type": "major",
3+
"description": "Add children to communities to avoid re-compute."
4+
}

docs/examples_notebooks/index_migration_to_v2.ipynb

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
"cells": [
33
{
44
"cell_type": "code",
5-
"execution_count": 41,
5+
"execution_count": 5,
66
"metadata": {},
77
"outputs": [],
88
"source": [
@@ -25,17 +25,17 @@
2525
},
2626
{
2727
"cell_type": "code",
28-
"execution_count": 42,
28+
"execution_count": null,
2929
"metadata": {},
3030
"outputs": [],
3131
"source": [
3232
"# This is the directory that has your settings.yaml\n",
33-
"PROJECT_DIRECTORY = \"<your project directory\""
33+
"PROJECT_DIRECTORY = \"<your project directory>\""
3434
]
3535
},
3636
{
3737
"cell_type": "code",
38-
"execution_count": 43,
38+
"execution_count": 2,
3939
"metadata": {},
4040
"outputs": [],
4141
"source": [
@@ -54,7 +54,7 @@
5454
},
5555
{
5656
"cell_type": "code",
57-
"execution_count": 44,
57+
"execution_count": 3,
5858
"metadata": {},
5959
"outputs": [],
6060
"source": [
@@ -65,7 +65,7 @@
6565
},
6666
{
6767
"cell_type": "code",
68-
"execution_count": 45,
68+
"execution_count": 4,
6969
"metadata": {},
7070
"outputs": [],
7171
"source": [
@@ -96,6 +96,30 @@
9696
" final_nodes.loc[:, [\"id\", \"degree\", \"x\", \"y\"]].groupby(\"id\").first().reset_index()\n",
9797
")\n",
9898
"final_entities = final_entities.merge(graph_props, on=\"id\", how=\"left\")\n",
99+
"# we're also persistint the frequency column\n",
100+
"final_entities[\"frequency\"] = final_entities[\"text_unit_ids\"].count()\n",
101+
"\n",
102+
"\n",
103+
"# we added children to communities to eliminate query-time reconstruction\n",
104+
"parent_grouped = final_communities.groupby(\"parent\").agg(\n",
105+
" children=(\"community\", \"unique\")\n",
106+
")\n",
107+
"final_communities = final_communities.merge(\n",
108+
" parent_grouped,\n",
109+
" left_on=\"community\",\n",
110+
" right_on=\"parent\",\n",
111+
" how=\"left\",\n",
112+
")\n",
113+
"\n",
114+
"# add children to the reports as well\n",
115+
"final_community_reports = final_community_reports.merge(\n",
116+
" parent_grouped,\n",
117+
" left_on=\"community\",\n",
118+
" right_on=\"parent\",\n",
119+
" how=\"left\",\n",
120+
")\n",
121+
"\n",
122+
"# copy children into the reports as well\n",
99123
"\n",
100124
"# we renamed all the output files for better clarity now that we don't have workflow naming constraints from DataShaper\n",
101125
"await write_table_to_storage(final_documents, \"documents\", storage)\n",

graphrag/index/flows/create_communities.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@
44
"""All the steps to transform final communities."""
55

66
from datetime import datetime, timezone
7+
from typing import cast
78
from uuid import uuid4
89

10+
import numpy as np
911
import pandas as pd
1012

1113
from graphrag.index.operations.cluster_graph import cluster_graph
@@ -92,7 +94,21 @@ def create_communities(
9294
str
9395
)
9496
final_communities["parent"] = final_communities["parent"].astype(int)
95-
97+
# collect the children so we have a tree going both ways
98+
parent_grouped = cast(
99+
"pd.DataFrame",
100+
final_communities.groupby("parent").agg(children=("community", "unique")),
101+
)
102+
final_communities = final_communities.merge(
103+
parent_grouped,
104+
left_on="community",
105+
right_on="parent",
106+
how="left",
107+
)
108+
# replace NaN children with empty list
109+
final_communities["children"] = final_communities["children"].apply(
110+
lambda x: x if isinstance(x, np.ndarray) else [] # type: ignore
111+
)
96112
# add fields for incremental update tracking
97113
final_communities["period"] = datetime.now(timezone.utc).date().isoformat()
98114
final_communities["size"] = final_communities.loc[:, "entity_ids"].apply(len)
@@ -103,8 +119,9 @@ def create_communities(
103119
"id",
104120
"human_readable_id",
105121
"community",
106-
"parent",
107122
"level",
123+
"parent",
124+
"children",
108125
"title",
109126
"entity_ids",
110127
"relationship_ids",

graphrag/index/flows/create_community_reports.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ async def create_community_reports(
6262

6363
community_reports = await summarize_communities(
6464
nodes,
65+
communities,
6566
local_contexts,
6667
build_level_context,
6768
callbacks,

graphrag/index/flows/create_community_reports_text.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ async def create_community_reports_text(
5353

5454
community_reports = await summarize_communities(
5555
nodes,
56+
communities,
5657
local_contexts,
5758
build_level_context,
5859
callbacks,

graphrag/index/operations/finalize_community_reports.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@ def finalize_community_reports(
1313
communities: pd.DataFrame,
1414
) -> pd.DataFrame:
1515
"""All the steps to transform final community reports."""
16-
# Merge with communities to add size and period
16+
# Merge with communities to add shared fields
1717
community_reports = reports.merge(
18-
communities.loc[:, ["community", "parent", "size", "period"]],
18+
communities.loc[:, ["community", "parent", "children", "size", "period"]],
1919
on="community",
2020
how="left",
2121
copy=False,
@@ -31,8 +31,9 @@ def finalize_community_reports(
3131
"id",
3232
"human_readable_id",
3333
"community",
34-
"parent",
3534
"level",
35+
"parent",
36+
"children",
3637
"title",
3738
"summary",
3839
"full_content",

graphrag/index/operations/summarize_communities/summarize_communities.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
)
2121
from graphrag.index.operations.summarize_communities.utils import (
2222
get_levels,
23-
restore_community_hierarchy,
2423
)
2524
from graphrag.index.run.derive_from_rows import derive_from_rows
2625
from graphrag.logger.progress import progress_ticker
@@ -30,6 +29,7 @@
3029

3130
async def summarize_communities(
3231
nodes: pd.DataFrame,
32+
communities: pd.DataFrame,
3333
local_contexts,
3434
level_context_builder: Callable,
3535
callbacks: WorkflowCallbacks,
@@ -49,7 +49,12 @@ async def summarize_communities(
4949
if strategy_config.get("llm") and strategy_config["llm"]["max_retries"] == -1:
5050
strategy_config["llm"]["max_retries"] = len(nodes)
5151

52-
community_hierarchy = restore_community_hierarchy(nodes)
52+
community_hierarchy = (
53+
communities.explode("children")
54+
.rename({"children": "sub_community"}, axis=1)
55+
.loc[:, ["community", "level", "sub_community"]]
56+
).dropna()
57+
5358
levels = get_levels(nodes)
5459

5560
level_contexts = []

graphrag/index/operations/summarize_communities/utils.py

Lines changed: 0 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33

44
"""A module containing community report generation utilities."""
55

6-
from itertools import pairwise
7-
86
import pandas as pd
97

108
import graphrag.model.schemas as schemas
@@ -17,48 +15,3 @@ def get_levels(
1715
levels = df[level_column].dropna().unique()
1816
levels = [int(lvl) for lvl in levels if lvl != -1]
1917
return sorted(levels, reverse=True)
20-
21-
22-
def restore_community_hierarchy(
23-
input: pd.DataFrame,
24-
name_column: str = schemas.TITLE,
25-
community_column: str = schemas.COMMUNITY_ID,
26-
level_column: str = schemas.COMMUNITY_LEVEL,
27-
) -> pd.DataFrame:
28-
"""Restore the community hierarchy from the node data."""
29-
# Group by community and level, aggregate names as lists
30-
community_df = (
31-
input.groupby([community_column, level_column])[name_column]
32-
.apply(set)
33-
.reset_index()
34-
)
35-
36-
# Build dictionary with levels as integers
37-
community_levels = {
38-
level: group.set_index(community_column)[name_column].to_dict()
39-
for level, group in community_df.groupby(level_column)
40-
}
41-
42-
# get unique levels, sorted in ascending order
43-
levels = sorted(community_levels.keys()) # type: ignore
44-
community_hierarchy = []
45-
46-
# Iterate through adjacent levels
47-
for current_level, next_level in pairwise(levels):
48-
current_communities = community_levels[current_level]
49-
next_communities = community_levels[next_level]
50-
51-
# Find sub-communities
52-
for curr_comm, curr_entities in current_communities.items():
53-
for next_comm, next_entities in next_communities.items():
54-
if next_entities.issubset(curr_entities):
55-
community_hierarchy.append({
56-
community_column: curr_comm,
57-
schemas.COMMUNITY_LEVEL: current_level,
58-
schemas.SUB_COMMUNITY: next_comm,
59-
schemas.SUB_COMMUNITY_SIZE: len(next_entities),
60-
})
61-
62-
return pd.DataFrame(
63-
community_hierarchy,
64-
)

graphrag/model/community.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,15 @@
1313
class Community(Named):
1414
"""A protocol for a community in the system."""
1515

16-
level: str = ""
16+
level: str
1717
"""Community level."""
1818

19+
parent: str
20+
"""Community ID of the parent node of this community."""
21+
22+
children: list[str]
23+
"""List of community IDs of the child nodes of this community."""
24+
1925
entity_ids: list[str] | None = None
2026
"""List of entity IDs related to the community (optional)."""
2127

@@ -25,9 +31,6 @@ class Community(Named):
2531
covariate_ids: dict[str, list[str]] | None = None
2632
"""Dictionary of different types of covariates related to the community (optional), e.g. claims"""
2733

28-
sub_community_ids: list[str] | None = None
29-
"""List of community IDs of the child nodes of this community (optional)."""
30-
3134
attributes: dict[str, Any] | None = None
3235
"""A dictionary of additional attributes associated with the community (optional). To be included in the search prompt."""
3336

@@ -48,7 +51,8 @@ def from_dict(
4851
entities_key: str = "entity_ids",
4952
relationships_key: str = "relationship_ids",
5053
covariates_key: str = "covariate_ids",
51-
sub_communities_key: str = "sub_community_ids",
54+
parent_key: str = "parent",
55+
children_key: str = "children",
5256
attributes_key: str = "attributes",
5357
size_key: str = "size",
5458
period_key: str = "period",
@@ -57,12 +61,13 @@ def from_dict(
5761
return Community(
5862
id=d[id_key],
5963
title=d[title_key],
60-
short_id=d.get(short_id_key),
6164
level=d[level_key],
65+
parent=d[parent_key],
66+
children=d[children_key],
67+
short_id=d.get(short_id_key),
6268
entity_ids=d.get(entities_key),
6369
relationship_ids=d.get(relationships_key),
6470
covariate_ids=d.get(covariates_key),
65-
sub_community_ids=d.get(sub_communities_key),
6671
attributes=d.get(attributes_key),
6772
size=d.get(size_key),
6873
period=d.get(period_key),

graphrag/model/schemas.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929

3030
# COMMUNITY HIERARCHY TABLE SCHEMA
3131
SUB_COMMUNITY = "sub_community"
32-
SUB_COMMUNITY_SIZE = "sub_community_size"
3332
COMMUNITY_LEVEL = "level"
3433

3534
# COMMUNITY CONTEXT TABLE SCHEMA

0 commit comments

Comments
 (0)