
Commit 0440580

Add Parent to communities in data model (#1491)
* Add Parent to communities in data model
* Semver
* Pyright
* Update docs
* Use leiden cluster parent id
* Format
1 parent 61816e0 commit 0440580

17 files changed: +60 -26 lines
Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Add Parent id to communities data model"
+}

docs/index/outputs.md

Lines changed: 2 additions & 0 deletions
@@ -16,6 +16,7 @@ This is a list of the final communities generated by Leiden. Communities are str
 | name | type | description |
 | ---------------- | ----- | ----------- |
 | community | int | Leiden-generated cluster ID for the community. Note that these increment with depth, so they are unique through all levels of the community hierarchy. For this table, human_readable_id is a copy of the community ID rather than a plain increment. |
+| parent | int | Parent community ID. |
 | level | int | Depth of the community in the hierarchy. |
 | title | str | Friendly name of the community. |
 | entity_ids | str[] | List of entities that are members of the community. |
@@ -30,6 +31,7 @@ This is the list of summarized reports for each community.
 | name | type | description |
 | ----------------- | ----- | ----------- |
 | community | int | Short ID of the community this report applies to. |
+| parent | int | Parent community ID. |
 | level | int | Level of the community this report applies to. |
 | title | str | LM-generated title for the report. |
 | summary | str | LM-generated summary of the report. |
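With the parent column on both tables, the community hierarchy can be walked directly from the output. A minimal sketch, assuming the communities table has been written to a parquet file (the path below is illustrative; the -1 root sentinel follows this commit's clustering code):

import pandas as pd

# Illustrative path; point this at your pipeline's output directory.
communities = pd.read_parquet("output/create_final_communities.parquet")

# Group children under each parent ID; root communities carry the -1 sentinel.
children_by_parent = communities.groupby("parent")["community"].apply(list)

for parent_id, child_ids in children_by_parent.items():
    label = "root" if parent_id == -1 else f"community {parent_id}"
    print(f"{label}: children {child_ids}")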

graphrag/index/flows/create_base_entity_graph.py

Lines changed: 5 additions & 7 deletions
@@ -3,7 +3,7 @@

 """All the steps to create the base entity graph."""

-from typing import Any, cast
+from typing import Any
 from uuid import uuid4

 import networkx as nx
@@ -159,12 +159,10 @@ def _prep_edges(relationships, summaries) -> pd.DataFrame:


 def _prep_communities(communities) -> pd.DataFrame:
-    base_communities = pd.DataFrame(
-        communities, columns=cast("Any", ["level", "community", "title"])
-    )
-    base_communities = base_communities.explode("title")
-    base_communities["community"] = base_communities["community"].astype(int)
-    return base_communities
+    # Convert the input into a DataFrame and explode the title column
+    return pd.DataFrame(
+        communities, columns=pd.Index(["level", "community", "parent", "title"])
+    ).explode("title")


 def _compute_degree(graph: nx.Graph) -> pd.DataFrame:
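For intuition: _prep_communities now receives (level, community, parent, nodes) tuples and explodes the node list into one row per member. A minimal sketch with made-up data:

import pandas as pd

# Made-up Communities value: (level, community, parent, member nodes).
communities = [
    (0, 0, -1, ["a", "b", "c"]),
    (1, 1, 0, ["a", "b"]),
    (1, 2, 0, ["c"]),
]

base = pd.DataFrame(
    communities, columns=pd.Index(["level", "community", "parent", "title"])
).explode("title")

# explode() repeats level/community/parent for every member node,
# so each row is one (community, node) pair.
print(base)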

graphrag/index/flows/create_final_communities.py

Lines changed: 12 additions & 3 deletions
@@ -38,16 +38,23 @@ def create_final_communities(
     matched = targets.loc[targets["community_x"] == targets["community_y"]]
     text_units = matched.explode("text_unit_ids")
     grouped = (
-        text_units.groupby(["community_x", "level_x"])
+        text_units.groupby(["community_x", "level_x", "parent_x"])
         .agg(relationship_ids=("id", list), text_unit_ids=("text_unit_ids", list))
         .reset_index()
     )
     grouped.rename(
-        columns={"community_x": "community", "level_x": "level"}, inplace=True
+        columns={
+            "community_x": "community",
+            "level_x": "level",
+            "parent_x": "parent",
+        },
+        inplace=True,
     )
     all_grouped = pd.concat([
         all_grouped,
-        grouped.loc[:, ["community", "level", "relationship_ids", "text_unit_ids"]],
+        grouped.loc[
+            :, ["community", "level", "parent", "relationship_ids", "text_unit_ids"]
+        ],
     ])

     # deduplicate the lists
@@ -63,6 +70,7 @@ def create_final_communities(
     communities["id"] = [str(uuid4()) for _ in range(len(communities))]
     communities["human_readable_id"] = communities["community"]
     communities["title"] = "Community " + communities["community"].astype(str)
+    communities["parent"] = communities["parent"].astype(int)

     # add fields for incremental update tracking
     communities["period"] = datetime.now(timezone.utc).date().isoformat()
@@ -74,6 +82,7 @@ def create_final_communities(
         "id",
         "human_readable_id",
         "community",
+        "parent",
         "level",
         "title",
         "entity_ids",

graphrag/index/flows/create_final_community_reports.py

Lines changed: 5 additions & 4 deletions
@@ -88,7 +88,7 @@ async def create_final_community_reports(

     # Merge with communities to add size and period
     merged = community_reports.merge(
-        communities.loc[:, ["community", "size", "period"]],
+        communities.loc[:, ["community", "parent", "size", "period"]],
         on="community",
         how="left",
         copy=False,
@@ -99,6 +99,7 @@ async def create_final_community_reports(
         "id",
         "human_readable_id",
         "community",
+        "parent",
         "level",
         "title",
         "summary",
@@ -124,7 +125,7 @@ def _prep_nodes(input: pd.DataFrame) -> pd.DataFrame:
     )

     # Create NODE_DETAILS column
-    input[NODE_DETAILS] = input.loc[
+    input.loc[:, NODE_DETAILS] = input.loc[
         :, [NODE_ID, NODE_NAME, NODE_DESCRIPTION, NODE_DEGREE]
     ].to_dict(orient="records")

@@ -136,7 +137,7 @@ def _prep_edges(input: pd.DataFrame) -> pd.DataFrame:
     input.fillna(value={NODE_DESCRIPTION: "No Description"}, inplace=True)

     # Create EDGE_DETAILS column
-    input[EDGE_DETAILS] = input.loc[
+    input.loc[:, EDGE_DETAILS] = input.loc[
         :, [EDGE_ID, EDGE_SOURCE, EDGE_TARGET, EDGE_DESCRIPTION, EDGE_DEGREE]
     ].to_dict(orient="records")

@@ -148,7 +149,7 @@ def _prep_claims(input: pd.DataFrame) -> pd.DataFrame:
     input.fillna(value={NODE_DESCRIPTION: "No Description"}, inplace=True)

     # Create CLAIM_DETAILS column
-    input[CLAIM_DETAILS] = input.loc[
+    input.loc[:, CLAIM_DETAILS] = input.loc[
         :, [CLAIM_ID, CLAIM_SUBJECT, CLAIM_TYPE, CLAIM_STATUS, CLAIM_DESCRIPTION]
     ].to_dict(orient="records")

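The input[COL] = ... to input.loc[:, COL] = ... rewrites do not change behavior on a plain DataFrame; .loc[:, col] is the explicit indexer form, which the Pyright pass mentioned in the commit message prefers and which avoids chained-assignment ambiguity. A tiny sketch of the equivalence (names illustrative):

import pandas as pd

df = pd.DataFrame({"a": [1, 2], "b": [3, 4]})

# Both assignments create/overwrite a column on df itself; each row
# ends up holding a dict of its own values, e.g. {"a": 1, "b": 3}.
df["details_plain"] = df.loc[:, ["a", "b"]].to_dict(orient="records")
df.loc[:, "details_loc"] = df.loc[:, ["a", "b"]].to_dict(orient="records")

print(df)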

graphrag/index/operations/cluster_graph.py

Lines changed: 17 additions & 12 deletions
@@ -11,7 +11,7 @@

 from graphrag.index.graph.utils import stable_largest_connected_component

-Communities = list[tuple[int, str, list[str]]]
+Communities = list[tuple[int, int, int, list[str]]]


 class GraphCommunityStrategyType(str, Enum):
@@ -41,25 +41,25 @@ def run_layout(strategy: dict[str, Any], graph: nx.Graph) -> Communities:
         log.warning("Graph has no nodes")
         return []

-    clusters: dict[int, dict[str, list[str]]] = {}
+    clusters: dict[int, dict[int, list[str]]] = {}
     strategy_type = strategy.get("type", GraphCommunityStrategyType.leiden)
     match strategy_type:
         case GraphCommunityStrategyType.leiden:
-            clusters = run_leiden(graph, strategy)
+            clusters, parent_mapping = run_leiden(graph, strategy)
         case _:
             msg = f"Unknown clustering strategy {strategy_type}"
             raise ValueError(msg)

     results: Communities = []
     for level in clusters:
         for cluster_id, nodes in clusters[level].items():
-            results.append((level, cluster_id, nodes))
+            results.append((level, cluster_id, parent_mapping[cluster_id], nodes))
     return results


 def run_leiden(
     graph: nx.Graph, args: dict[str, Any]
-) -> dict[int, dict[str, list[str]]]:
+) -> tuple[dict[int, dict[int, list[str]]], dict[int, int]]:
     """Run method definition."""
     max_cluster_size = args.get("max_cluster_size", 10)
     use_lcc = args.get("use_lcc", True)
@@ -68,7 +68,7 @@ def run_leiden(
         "Running leiden with max_cluster_size=%s, lcc=%s", max_cluster_size, use_lcc
     )

-    node_id_to_community_map = _compute_leiden_communities(
+    node_id_to_community_map, community_hierarchy_map = _compute_leiden_communities(
         graph=graph,
         max_cluster_size=max_cluster_size,
         use_lcc=use_lcc,
@@ -80,16 +80,16 @@ def run_leiden(
     if levels is None:
         levels = sorted(node_id_to_community_map.keys())

-    results_by_level: dict[int, dict[str, list[str]]] = {}
+    results_by_level: dict[int, dict[int, list[str]]] = {}
     for level in levels:
         result = {}
         results_by_level[level] = result
         for node_id, raw_community_id in node_id_to_community_map[level].items():
-            community_id = str(raw_community_id)
+            community_id = raw_community_id
             if community_id not in result:
                 result[community_id] = []
             result[community_id].append(node_id)
-    return results_by_level
+    return results_by_level, community_hierarchy_map


 # Taken from graph_intelligence & adapted
@@ -98,8 +98,8 @@ def _compute_leiden_communities(
     max_cluster_size: int,
     use_lcc: bool,
     seed=0xDEADBEEF,
-) -> dict[int, dict[str, int]]:
-    """Return Leiden root communities."""
+) -> tuple[dict[int, dict[str, int]], dict[int, int]]:
+    """Return Leiden root communities and their hierarchy mapping."""
     # NOTE: This import is done here to reduce the initial import time of the graphrag package
     from graspologic.partition import hierarchical_leiden

@@ -110,8 +110,13 @@ def _compute_leiden_communities(
         graph, max_cluster_size=max_cluster_size, random_seed=seed
     )
     results: dict[int, dict[str, int]] = {}
+    hierarchy: dict[int, int] = {}
     for partition in community_mapping:
         results[partition.level] = results.get(partition.level, {})
         results[partition.level][partition.node] = partition.cluster

-    return results
+        hierarchy[partition.cluster] = (
+            partition.parent_cluster if partition.parent_cluster is not None else -1
+        )
+
+    return results, hierarchy
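For intuition: the new hierarchy map pairs each Leiden cluster ID with its parent's cluster ID, with -1 marking roots. A minimal sketch of the same bookkeeping over stand-in partition records (graspologic's hierarchical_leiden yields objects with these fields; the namedtuple below is only a stand-in for illustration):

from collections import namedtuple

# Stand-in for graspologic partition records; only the fields this
# commit reads are modeled (level, node, cluster, parent_cluster).
Partition = namedtuple("Partition", ["level", "node", "cluster", "parent_cluster"])

community_mapping = [
    Partition(0, "a", 0, None),  # level-0 clusters have no parent
    Partition(0, "b", 0, None),
    Partition(1, "a", 1, 0),     # level-1 clusters point at a level-0 parent
    Partition(1, "b", 2, 0),
]

results: dict[int, dict[str, int]] = {}
hierarchy: dict[int, int] = {}
for partition in community_mapping:
    results[partition.level] = results.get(partition.level, {})
    results[partition.level][partition.node] = partition.cluster
    # Root clusters get the -1 sentinel instead of None.
    hierarchy[partition.cluster] = (
        partition.parent_cluster if partition.parent_cluster is not None else -1
    )

print(hierarchy)  # {0: -1, 1: 0, 2: 0}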

graphrag/index/update/communities.py

Lines changed: 15 additions & 0 deletions
@@ -54,6 +54,7 @@ def _merge_and_resolve_nodes(
         v: v + old_max_community_id + 1
         for k, v in delta_nodes["community"].dropna().astype(int).items()
     }
+    community_id_mapping.update({-1: -1})

     delta_nodes["community"] = delta_nodes["community"].where(
         delta_nodes["community"].isna(),
@@ -130,6 +131,12 @@ def _update_and_merge_communities(
         .apply(lambda x: community_id_mapping.get(x, x))
     )

+    delta_communities["parent"] = (
+        delta_communities["parent"]
+        .astype(int)
+        .apply(lambda x: community_id_mapping.get(x, x))
+    )
+
     old_communities["community"] = old_communities["community"].astype(int)

     # Merge the final communities
@@ -150,6 +157,7 @@ def _update_and_merge_communities(
         "id",
         "human_readable_id",
         "community",
+        "parent",
         "level",
         "title",
         "entity_ids",
@@ -201,6 +209,12 @@ def _update_and_merge_community_reports(
         .apply(lambda x: community_id_mapping.get(x, x))
     )

+    delta_community_reports["parent"] = (
+        delta_community_reports["parent"]
+        .astype(int)
+        .apply(lambda x: community_id_mapping.get(x, x))
+    )
+
     old_community_reports["community"] = old_community_reports["community"].astype(int)

     # Merge the final community reports
@@ -223,6 +237,7 @@ def _update_and_merge_community_reports(
         "id",
         "human_readable_id",
         "community",
+        "parent",
         "level",
         "title",
         "summary",
Binary file not shown (-500 Bytes)
Binary file not shown (-592 Bytes)
Binary file not shown (1.56 KB)

0 commit comments