Skip to content

Commit 04d9f58

Browse files
authored
Re-implement hierarchical Leiden (#2049)
* Use graspologic-native hierarchical leiden * Re-implement largest_connected_component * Copy in modularity * Use graspologic-native directly in pyproject * Remove directed graph tests (we don't use this) * Semver * Remove graspologic dep
1 parent 97704ab commit 04d9f58

File tree

8 files changed

+133
-488
lines changed

8 files changed

+133
-488
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"type": "major",
3+
"description": "Re-implement graspologic methods to remove dependency. Remove visualization steps."
4+
}

graphrag/index/operations/cluster_graph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
import logging
77

88
import networkx as nx
9-
from graspologic.partition import hierarchical_leiden
109

10+
from graphrag.index.utils.graphs import hierarchical_leiden
1111
from graphrag.index.utils.stable_lcc import stable_largest_connected_component
1212

1313
Communities = list[tuple[int, int, int, list[str]]]

graphrag/index/operations/prune_graph.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55

66
from typing import TYPE_CHECKING, cast
77

8-
import graspologic as glc
98
import networkx as nx
109
import numpy as np
1110

1211
import graphrag.data_model.schemas as schemas
12+
from graphrag.index.utils.graphs import largest_connected_component
1313

1414
if TYPE_CHECKING:
1515
from networkx.classes.reportviews import DegreeView
@@ -78,7 +78,7 @@ def prune_graph(
7878
])
7979

8080
if lcc_only:
81-
return glc.utils.largest_connected_component(graph) # type: ignore
81+
return largest_connected_component(graph)
8282

8383
return graph
8484

graphrag/index/utils/graphs.py

Lines changed: 118 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,136 @@
11
# Copyright (c) 2024 Microsoft Corporation.
22
# Licensed under the MIT License
33

4-
"""Collection of graph utility functions."""
4+
"""
5+
Collection of graph utility functions.
6+
7+
These are largely copies/re-implementations of graspologic methods to avoid dependency issues.
8+
"""
59

610
import logging
7-
from typing import cast
11+
import math
12+
from collections import defaultdict
13+
from typing import Any, cast
814

15+
import graspologic_native as gn
916
import networkx as nx
1017
import numpy as np
1118
import pandas as pd
12-
from graspologic.partition import hierarchical_leiden, modularity
13-
from graspologic.utils import largest_connected_component
1419

1520
from graphrag.config.enums import ModularityMetric
1621

1722
logger = logging.getLogger(__name__)
1823

1924

25+
def largest_connected_component(graph: nx.Graph) -> nx.Graph:
26+
"""Return the largest connected component of the graph."""
27+
graph = graph.copy()
28+
lcc_nodes = max(nx.connected_components(graph), key=len)
29+
lcc = graph.subgraph(lcc_nodes).copy()
30+
lcc.remove_nodes_from([n for n in lcc if n not in lcc_nodes])
31+
return cast("nx.Graph", lcc)
32+
33+
34+
def _nx_to_edge_list(
35+
graph: nx.Graph,
36+
weight_attribute: str = "weight",
37+
weight_default: float = 1.0,
38+
) -> list[tuple[str, str, float]]:
39+
"""
40+
Convert an undirected, non-multigraph networkx graph to a list of edges.
41+
42+
Each edge is represented as a tuple of (source_str, target_str, weight).
43+
"""
44+
edge_list: list[tuple[str, str, float]] = []
45+
46+
# Decide how to retrieve the weight data
47+
edge_iter = graph.edges(data=weight_attribute, default=weight_default) # type: ignore
48+
49+
for source, target, weight in edge_iter:
50+
source_str = str(source)
51+
target_str = str(target)
52+
edge_list.append((source_str, target_str, float(weight)))
53+
54+
return edge_list
55+
56+
57+
def hierarchical_leiden(
58+
graph: nx.Graph,
59+
max_cluster_size: int = 10,
60+
random_seed: int | None = 0xDEADBEEF,
61+
) -> Any:
62+
"""Run hierarchical leiden on the graph."""
63+
return gn.hierarchical_leiden(
64+
edges=_nx_to_edge_list(graph),
65+
max_cluster_size=max_cluster_size,
66+
seed=random_seed,
67+
starting_communities=None,
68+
resolution=1.0,
69+
randomness=0.001,
70+
use_modularity=True,
71+
iterations=1,
72+
)
73+
74+
75+
def modularity(
76+
graph: nx.Graph,
77+
partitions: dict[Any, int],
78+
weight_attribute: str = "weight",
79+
resolution: float = 1.0,
80+
) -> float:
81+
"""Given an undirected graph and a dictionary of vertices to community ids, calculate the modularity."""
82+
components = _modularity_components(graph, partitions, weight_attribute, resolution)
83+
return sum(components.values())
84+
85+
86+
def _modularity_component(
87+
intra_community_degree: float,
88+
total_community_degree: float,
89+
network_degree_sum: float,
90+
resolution: float,
91+
) -> float:
92+
community_degree_ratio = math.pow(total_community_degree, 2.0) / (
93+
2.0 * network_degree_sum
94+
)
95+
return (intra_community_degree - resolution * community_degree_ratio) / (
96+
2.0 * network_degree_sum
97+
)
98+
99+
100+
def _modularity_components(
101+
graph: nx.Graph,
102+
partitions: dict[Any, int],
103+
weight_attribute: str = "weight",
104+
resolution: float = 1.0,
105+
) -> dict[int, float]:
106+
total_edge_weight = 0.0
107+
communities = set(partitions.values())
108+
109+
degree_sums_within_community: dict[int, float] = defaultdict(lambda: 0.0)
110+
degree_sums_for_community: dict[int, float] = defaultdict(lambda: 0.0)
111+
for vertex, neighbor_vertex, weight in graph.edges(data=weight_attribute):
112+
vertex_community = partitions[vertex]
113+
neighbor_community = partitions[neighbor_vertex]
114+
if vertex_community == neighbor_community:
115+
if vertex == neighbor_vertex:
116+
degree_sums_within_community[vertex_community] += weight
117+
else:
118+
degree_sums_within_community[vertex_community] += weight * 2.0
119+
degree_sums_for_community[vertex_community] += weight
120+
degree_sums_for_community[neighbor_community] += weight
121+
total_edge_weight += weight
122+
123+
return {
124+
comm: _modularity_component(
125+
degree_sums_within_community[comm],
126+
degree_sums_for_community[comm],
127+
total_edge_weight,
128+
resolution,
129+
)
130+
for comm in communities
131+
}
132+
133+
20134
def calculate_root_modularity(
21135
graph: nx.Graph,
22136
max_cluster_size: int = 10,
@@ -147,9 +261,6 @@ def calculate_modularity(
147261
random_seed=random_seed,
148262
use_root_modularity=use_root_modularity,
149263
)
150-
case _:
151-
msg = f"Unknown modularity metric type: {modularity_metric}"
152-
raise ValueError(msg)
153264

154265

155266
def calculate_pmi_edge_weights(

graphrag/index/utils/stable_lcc.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,11 @@
88

99
import networkx as nx
1010

11+
from graphrag.index.utils.graphs import largest_connected_component
12+
1113

1214
def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph:
1315
"""Return the largest connected component of the graph, with nodes and edges sorted in a stable way."""
14-
# NOTE: The import is done here to reduce the initial import time of the module
15-
from graspologic.utils import largest_connected_component
16-
1716
graph = graph.copy()
1817
graph = cast("nx.Graph", largest_connected_component(graph))
1918
graph = normalize_node_names(graph)

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ dependencies = [
4646
"tiktoken>=0.9.0",
4747
# Data-Science
4848
"numpy>=1.25.2",
49-
"graspologic>=3.4.1",
5049
"networkx>=3.4.2",
5150
"pandas>=2.2.3",
5251
"pyarrow>=17.0.0",
@@ -66,6 +65,7 @@ dependencies = [
6665
"tqdm>=4.67.1",
6766
"textblob>=0.18.0.post0",
6867
"spacy>=3.8.4",
68+
"graspologic-native>=1.2.5",
6969
]
7070

7171
[project.optional-dependencies]
@@ -260,4 +260,4 @@ exclude = ["**/node_modules", "**/__pycache__"]
260260
asyncio_default_fixture_loop_scope = "function"
261261
asyncio_mode = "auto"
262262
timeout = 1000
263-
env_files = [".env"]
263+
env_files = [".env"]

tests/unit/indexing/graph/utils/test_stable_lcc.py

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -20,32 +20,6 @@ def test_undirected_graph_run_twice_produces_same_graph(self):
2020
nx.generate_graphml(graph_out_2)
2121
)
2222

23-
def test_directed_graph_keeps_source_target_intact(self):
24-
# create the test graph as a directed graph
25-
graph_in = self._create_strongly_connected_graph_with_edges_flipped(
26-
digraph=True
27-
)
28-
graph_out = stable_largest_connected_component(graph_in.copy())
29-
30-
# Make sure edges are the same and the direction is preserved
31-
edges_1 = [f"{edge[0]} -> {edge[1]}" for edge in graph_in.edges(data=True)]
32-
edges_2 = [f"{edge[0]} -> {edge[1]}" for edge in graph_out.edges(data=True)]
33-
34-
assert edges_1 == edges_2
35-
36-
def test_directed_graph_run_twice_produces_same_graph(self):
37-
# create the test graph as a directed graph
38-
graph_in = self._create_strongly_connected_graph_with_edges_flipped(
39-
digraph=True
40-
)
41-
graph_out_1 = stable_largest_connected_component(graph_in.copy())
42-
graph_out_2 = stable_largest_connected_component(graph_in.copy())
43-
44-
# Make sure the output is identical when run multiple times
45-
assert "".join(nx.generate_graphml(graph_out_1)) == "".join(
46-
nx.generate_graphml(graph_out_2)
47-
)
48-
4923
def _create_strongly_connected_graph(self, digraph=False):
5024
graph = nx.Graph() if not digraph else nx.DiGraph()
5125
graph.add_node("1", node_name=1)

0 commit comments

Comments
 (0)