Commit bd97987
Reduce find_indirect_clusters() runtime through neighborhood detection and sampling (#2144)
Like @mludvig found in #2071, I've found `find_indirect_clusters`' use of exhaustive depth-first search to be a significant barrier to use on even a moderately substantial knowledge graph. That PR uses BFS to identify a set of disjoint clusters involving the source node (each node appears in at most one cluster), whereas the original `find_indirect_clusters` identifies all clusters up to length depth_limit from each node. @mludvig, if I'm out of line here, please correct me!

The approach in this PR instead identifies neighborhoods in the graph using the Leiden clustering algorithm and samples paths from those neighborhoods. I believe this to be a better approach: in my testing it is even faster than BFS, and the function returns something more in line with the original `find_indirect_clusters` implementation. I would have preferred (and, in fact, originally tried) a Clique Percolation Method approach, because it allows nodes to belong to multiple neighborhoods; however, CPM also runs into NP-hard / NP-complete runtime issues. This PR does add two dependencies to Ragas: networkx and scikit-network.
1 parent 861dcff commit bd97987
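
At its core, the change replaces exhaustive DFS with Leiden community detection followed by per-neighborhood path search. Below is a minimal sketch of that detection step, using the same scikit-network calls as the diff further down; the toy edge list and node names are illustrative, not from the PR:

```python
from sknetwork.clustering import Leiden
from sknetwork.data import from_edge_list

# Toy stand-in for the (source_id, target_id) pairs that the real code
# draws from KnowledgeGraph.relationships; these names are hypothetical.
edges = [("A", "B"), ("B", "C"), ("C", "A"), ("D", "E"), ("E", "F")]
graph = from_edge_list(edges, directed=True)

# Leiden assigns each node exactly one community label; nodes sharing a
# label form a neighborhood that is then searched (or sampled) for paths.
labels = Leiden(random_state=42).fit_predict(graph.adjacency)
for name, label in zip(graph.names, labels):
    print(name, "->", int(label))
```

This one-label-per-node behavior is the trade-off versus Clique Percolation noted above: CPM would let nodes belong to multiple neighborhoods, but at NP-hard cost.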

File tree

4 files changed: +621 −103 lines changed

pyproject.toml

Lines changed: 2 additions & 0 deletions

```diff
@@ -19,6 +19,8 @@ dependencies = [
     "instructor",
     "gitpython",
     "pillow>=10.4.0",
+    "networkx",
+    "scikit-network",
 
     # LangChain ecosystem
     "langchain",
```

src/ragas/testset/graph.py

Lines changed: 191 additions & 40 deletions

```diff
@@ -1,12 +1,14 @@
 import json
 import typing as t
 import uuid
+from collections import defaultdict
 from copy import deepcopy
 from dataclasses import dataclass, field
 from enum import Enum
 from pathlib import Path
 
 from pydantic import BaseModel, Field, field_serializer
+from tqdm.auto import tqdm
 
 
 class UUIDEncoder(json.JSONEncoder):
@@ -250,67 +252,216 @@ def __repr__(self) -> str:
     def __str__(self) -> str:
         return self.__repr__()
 
+    def get_node_by_id(self, node_id: t.Union[uuid.UUID, str]) -> t.Optional[Node]:
+        """
+        Retrieves a node by its ID.
+
+        Parameters
+        ----------
+        node_id : uuid.UUID
+            The ID of the node to retrieve.
+
+        Returns
+        -------
+        Node or None
+            The node with the specified ID, or None if not found.
+        """
+        if isinstance(node_id, str):
+            node_id = uuid.UUID(node_id)
+
+        return next(filter(lambda n: n.id == node_id, self.nodes), None)
+
     def find_indirect_clusters(
         self,
         relationship_condition: t.Callable[[Relationship], bool] = lambda _: True,
         depth_limit: int = 3,
     ) -> t.List[t.Set[Node]]:
         """
-        Finds indirect clusters of nodes in the knowledge graph based on a relationship condition.
-        Here if A -> B -> C -> D, then A, B, C, and D form a cluster. If there's also a path A -> B -> C -> E,
-        it will form a separate cluster.
+        Finds "indirect clusters" of nodes in the knowledge graph based on a relationship condition.
+        Uses Leiden algorithm for community detection and identifies unique paths within each cluster.
+
+        NOTE: "indirect clusters" as used in the method name are
+        "groups of nodes that are not directly connected
+        but share a common relationship through other nodes",
+        while the Leiden algorithm is a "clustering" algorithm that defines
+        neighborhoods of nodes based on their connections --
+        these definitions of "cluster" are NOT equivalent.
 
         Parameters
         ----------
         relationship_condition : Callable[[Relationship], bool], optional
             A function that takes a Relationship and returns a boolean, by default lambda _: True
+        depth_limit : int, optional
+            The maximum depth of relationships (number of edges) to consider for clustering, by default 3.
 
         Returns
         -------
         List[Set[Node]]
             A list of sets, where each set contains nodes that form a cluster.
         """
-        clusters = []
-        visited_paths = set()
-
-        relationships = [
-            rel for rel in self.relationships if relationship_condition(rel)
-        ]
 
-        def dfs(node: Node, cluster: t.Set[Node], depth: int, path: t.Tuple[Node, ...]):
-            if depth >= depth_limit or path in visited_paths:
-                return
-            visited_paths.add(path)
-            cluster.add(node)
+        import networkx as nx
+
+        def get_node_clusters(
+            relationships: list[Relationship],
+        ) -> dict[int, set[uuid.UUID]]:
+            """Identify clusters of nodes using Leiden algorithm."""
+            import numpy as np
+            from sknetwork.clustering import Leiden
+            from sknetwork.data import Dataset as SKDataset, from_edge_list
+
+            # NOTE: the upstream sknetwork Dataset has some issues with type hints,
+            # so we use type: ignore to bypass them.
+            graph: SKDataset = from_edge_list(  # type: ignore
+                [(str(rel.source.id), str(rel.target.id)) for rel in relationships],
+                directed=True,
+            )
 
+            # Apply Leiden clustering
+            leiden = Leiden(random_state=42)
+            cluster_labels: np.ndarray = leiden.fit_predict(graph.adjacency)
+
+            # Group nodes by cluster
+            clusters: defaultdict[int, set[uuid.UUID]] = defaultdict(set)
+            for label, node_id in zip(cluster_labels, graph.names):
+                clusters[int(label)].add(uuid.UUID(node_id))
+
+            return dict(clusters)
+
+        def to_nx_digraph(
+            nodes: set[uuid.UUID], relationships: list[Relationship]
+        ) -> nx.DiGraph:
+            """Convert a set of nodes and relationships to a directed graph."""
+            # Create directed subgraph for this cluster
+            graph = nx.DiGraph()
+            for node_id in nodes:
+                graph.add_node(
+                    node_id,
+                    node_obj=self.get_node_by_id(node_id),
+                )
             for rel in relationships:
-                neighbor = None
-                if rel.source == node and rel.target not in cluster:
-                    neighbor = rel.target
-                elif (
-                    rel.bidirectional
-                    and rel.target == node
-                    and rel.source not in cluster
-                ):
-                    neighbor = rel.source
-
-                if neighbor is not None:
-                    dfs(neighbor, cluster.copy(), depth + 1, path + (neighbor,))
-
-            # Add completed path-based cluster
-            if len(cluster) > 1:
-                clusters.append(cluster)
-
-        for node in self.nodes:
-            initial_cluster = set()
-            dfs(node, initial_cluster, 0, (node,))
-
-        # Remove duplicates by converting clusters to frozensets
-        unique_clusters = [
-            set(cluster) for cluster in set(frozenset(c) for c in clusters)
-        ]
+                if rel.source.id in nodes and rel.target.id in nodes:
+                    graph.add_edge(rel.source.id, rel.target.id, relationship_obj=rel)
+            return graph
+
+        def max_simple_paths(n: int, k: int = depth_limit) -> int:
+            """Estimate the number of paths up to depth_limit that would exist in a fully-connected graph of size cluster_nodes."""
+            from math import prod
+
+            if n - k - 1 <= 0:
+                return 0
+
+            return prod(n - i for i in range(k + 1))
+
+        def exhaustive_paths(
+            graph: nx.DiGraph, depth_limit: int
+        ) -> list[list[uuid.UUID]]:
+            """Find all simple paths in the subgraph up to depth_limit."""
+            import itertools
+
+            # Check if graph has enough nodes for meaningful paths
+            if len(graph) < 2:
+                return []
+
+            all_paths: list[list[uuid.UUID]] = []
+            for source, target in itertools.permutations(graph.nodes(), 2):
+                if not nx.has_path(graph, source, target):
+                    continue
+                try:
+                    paths = nx.all_simple_paths(
+                        graph,
+                        source,
+                        target,
+                        cutoff=depth_limit,
+                    )
+                    all_paths.extend(paths)
+                except nx.NetworkXNoPath:
+                    continue
+
+            return all_paths
+
+        def sample_paths_from_graph(
+            graph: nx.DiGraph, depth_limit: int, sample_size: int = 1000
+        ) -> list[list[uuid.UUID]]:
+            """Sample random paths in the graph up to depth_limit."""
+            # we're using a DiGraph, so we need to account for directionality
+            # if a node has no out-paths, then it will cause an error in `generate_random_paths`
+
+            # Iteratively remove nodes with no out-paths to handle cascading effects
+            while True:
+                nodes_with_no_outpaths = [
+                    n for n in graph.nodes() if graph.out_degree(n) == 0
+                ]
+                if not nodes_with_no_outpaths:
+                    break
+                graph.remove_nodes_from(nodes_with_no_outpaths)
+
+            # Check if graph is empty after node removal
+            if len(graph) == 0:
+                return []
+
+            sampled_paths: list[list[uuid.UUID]] = []
+            for depth in range(2, depth_limit + 1):
+                # Additional safety check before generating paths
+                if (
+                    len(graph) < depth + 1
+                ):  # Need at least depth+1 nodes for a path of length depth
+                    continue
+
+                paths = nx.generate_random_paths(
+                    graph,
+                    sample_size=sample_size,
+                    path_length=depth,
+                )
+                sampled_paths.extend(paths)
+            return sampled_paths
+
+        # depth 2: 3 nodes, 2 edges (A -> B -> C)
+        if depth_limit < 2:
+            raise ValueError("Depth limit must be at least 2")
+
+        # Filter relationships based on the condition
+        filtered_relationships: list[Relationship] = []
+        relationship_map: defaultdict[uuid.UUID, set[uuid.UUID]] = defaultdict(set)
+        for rel in self.relationships:
+            if relationship_condition(rel):
+                filtered_relationships.append(rel)
+                relationship_map[rel.source.id].add(rel.target.id)
+                if rel.bidirectional:
+                    relationship_map[rel.target.id].add(rel.source.id)
+
+        if not filtered_relationships:
+            return []
+
+        clusters = get_node_clusters(filtered_relationships)
+
+        # For each cluster, find valid paths up to depth_limit
+        cluster_sets: set[frozenset] = set()
+        for _cluster_label, cluster_nodes in tqdm(
+            clusters.items(), desc="Processing clusters"
+        ):
+            if len(cluster_nodes) < depth_limit:
+                continue
+
+            subgraph = to_nx_digraph(
+                nodes=cluster_nodes, relationships=filtered_relationships
+            )
 
-        return unique_clusters
+            sampled_paths: list[list[uuid.UUID]] = []
+            # if the expected number of paths is small, use exhaustive search
+            # otherwise sample with random walks
+            if max_simple_paths(n=len(cluster_nodes), k=depth_limit) < 1000:
+                sampled_paths.extend(exhaustive_paths(subgraph, depth_limit))
+            else:
+                sampled_paths.extend(sample_paths_from_graph(subgraph, depth_limit))
+
+            # convert paths (node IDs) to sets of Node objects
+            # and deduplicate
+            for path in sampled_paths:
+                path_nodes = {subgraph.nodes[node_id]["node_obj"] for node_id in path}
+                cluster_sets.add(frozenset(path_nodes))
+
+        return [set(path_nodes) for path_nodes in cluster_sets]
 
     def remove_node(
         self, node: Node, inplace: bool = True
```

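For reference, a usage sketch of the updated method. This assumes the standard ragas graph types (`Node`, `NodeType`, `Relationship`, `KnowledgeGraph`); the node contents and relationship type are illustrative:

```python
from ragas.testset.graph import KnowledgeGraph, Node, NodeType, Relationship

# Build a tiny chain a -> b -> c; property values are illustrative.
a, b, c = (
    Node(type=NodeType.CHUNK, properties={"page_content": text})
    for text in ("chunk a", "chunk b", "chunk c")
)
kg = KnowledgeGraph(
    nodes=[a, b, c],
    relationships=[
        Relationship(source=a, target=b, type="next"),
        Relationship(source=b, target=c, type="next"),
    ],
)

# A 3-node cluster falls well under the 1000-path threshold estimated by
# max_simple_paths, so the exhaustive branch runs rather than sampling.
clusters = kg.find_indirect_clusters(depth_limit=2)
# Expect node sets for each unique path, e.g. {a, b}, {b, c}, {a, b, c}.
```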