
Commit 3d1a0d2

anistark and aubford authored

Knowledge graph/optimize for large corpus (#2267)

contd... #1967

Co-authored-by: Aubrey <[email protected]>

1 parent 0710f21 commit 3d1a0d2

File tree

7 files changed: +1452 -16 lines changed

.gitignore

Lines changed: 7 additions & 1 deletion

@@ -164,6 +164,11 @@ cython_debug/
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 .idea/
 
+# Cursor
+.cursorignore
+.cursor/*
+!.cursor/rules/
+
 # Ragas specific
 _experiments/
 **/fil-result/
@@ -174,7 +179,8 @@ examples/ragas_examples/_version.py
 .envrc
 uv.lock
 .cache/
-.claude
+.claude/*
+!.claude/commands/
 node_modules
 
 # Ragas examples

src/ragas/testset/graph.py

Lines changed: 186 additions & 5 deletions
@@ -1,4 +1,6 @@
+import hashlib
 import json
+import random
 import typing as t
 import uuid
 from collections import defaultdict
@@ -312,19 +314,21 @@ def get_node_clusters(
 
         # NOTE: the upstream sknetwork Dataset has some issues with type hints,
         # so we use type: ignore to bypass them.
+        # Use hex representation to ensure proper UUID strings for clustering
         graph: SKDataset = from_edge_list(  # type: ignore
-            [(str(rel.source.id), str(rel.target.id)) for rel in relationships],
+            [(rel.source.id.hex, rel.target.id.hex) for rel in relationships],
             directed=True,
         )
 
         # Apply Leiden clustering
         leiden = Leiden(random_state=42)
-        cluster_labels: np.ndarray = leiden.fit_predict(graph.adjacency)
+        cluster_labels: np.ndarray = leiden.fit_predict(graph["adjacency"])
 
         # Group nodes by cluster
         clusters: defaultdict[int, set[uuid.UUID]] = defaultdict(set)
-        for label, node_id in zip(cluster_labels, graph.names):
-            clusters[int(label)].add(uuid.UUID(node_id))
+        for label, node_id_hex in zip(cluster_labels, graph["names"]):
+            # node_id_hex is the hex string representation of the UUID
+            clusters[int(label)].add(uuid.UUID(hex=node_id_hex))
 
         return dict(clusters)
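As an aside, the switch to .hex hands sknetwork's from_edge_list plain 32-character strings instead of hyphenated UUIDs, and uuid.UUID(hex=...) restores the original IDs losslessly. A minimal round-trip sketch (standalone, not part of the commit):

import uuid

node_id = uuid.uuid4()
assert node_id.hex == str(node_id).replace("-", "")  # .hex is the UUID without hyphens
assert uuid.UUID(hex=node_id.hex) == node_id         # lossless round trip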

@@ -440,7 +444,8 @@ def sample_paths_from_graph(
         for _cluster_label, cluster_nodes in tqdm(
             clusters.items(), desc="Processing clusters"
         ):
-            if len(cluster_nodes) < depth_limit:
+            # Skip clusters that are too small to form any meaningful paths (need at least 2 nodes)
+            if len(cluster_nodes) < 2:
                 continue
 
             subgraph = to_nx_digraph(
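Note the relaxed guard: the old check dropped any cluster smaller than depth_limit, even though such clusters can still yield valid paths of two or more nodes (this is also what the updated unit test below exercises). A tiny standalone illustration of the difference, with a hypothetical 3-node cluster:

cluster_nodes = {"A", "B", "C"}  # hypothetical cluster
depth_limit = 4

old_skip = len(cluster_nodes) < depth_limit  # True  -> cluster used to be dropped
new_skip = len(cluster_nodes) < 2            # False -> cluster is now kept
print(old_skip, new_skip)  # True False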
@@ -463,6 +468,182 @@ def sample_paths_from_graph(
 
         return [set(path_nodes) for path_nodes in cluster_sets]
 
+    def find_n_indirect_clusters(
+        self,
+        n: int,
+        relationship_condition: t.Callable[[Relationship], bool] = lambda _: True,
+        depth_limit: int = 3,
+    ) -> t.List[t.Set[Node]]:
+        """
+        Return n indirect clusters of nodes in the knowledge graph based on a relationship condition.
+        Optimized for large datasets by using an adjacency index for lookups and limiting path exploration
+        relative to n.
+
+        A cluster represents a path through the graph. For example, if A -> B -> C -> D exists in the graph,
+        then {A, B, C, D} forms a cluster. If there's also a path A -> B -> C -> E, it forms a separate cluster.
+
+        The method returns a list of up to n sets, where each set contains nodes forming a complete path
+        from a starting node to a leaf node or a path segment up to depth_limit nodes long. The result may contain
+        fewer than n clusters if the graph is very sparse or if there aren't enough nodes to form n distinct clusters.
+
+        To maximize diversity in the results:
+        1. Random starting nodes are selected
+        2. Paths from each starting node are grouped
+        3. Clusters are selected in round-robin fashion from each group until n unique clusters are found
+        4. Duplicate clusters are eliminated
+        5. When a superset cluster is found (e.g., {A,B,C,D}), any existing subset clusters (e.g., {A,B,C})
+           are removed to avoid redundancy
+
+        Parameters
+        ----------
+        n : int
+            Target number of clusters to return. Must be at least 1. Should return n clusters unless the graph is
+            extremely sparse.
+        relationship_condition : Callable[[Relationship], bool], optional
+            A function that takes a Relationship and returns a boolean, by default lambda _: True
+        depth_limit : int, optional
+            Maximum depth for path exploration, by default 3. Must be at least 2 to form clusters by definition.
+
+        Returns
+        -------
+        List[Set[Node]]
+            A list of sets, where each set contains nodes that form a cluster.
+
+        Raises
+        ------
+        ValueError
+            If depth_limit < 2, n < 1, or no relationships match the provided condition.
+        """
+        if depth_limit < 2:
+            raise ValueError("depth_limit must be at least 2 to form valid clusters")
+
+        if n < 1:
+            raise ValueError("n must be at least 1")
+
+        # Filter relationships once upfront
+        filtered_relationships: list[Relationship] = [
+            rel for rel in self.relationships if relationship_condition(rel)
+        ]
+
+        if not filtered_relationships:
+            raise ValueError(
+                "No relationships match the provided condition. Cannot form clusters."
+            )
+
+        # Build an adjacency list for faster neighbor lookup - optimized for large datasets
+        adjacency_list: dict[Node, set[Node]] = {}
+        unique_edges: set[frozenset[Node]] = set()
+        for rel in filtered_relationships:
+            # Lazy initialization, since we only care about nodes with relationships
+            if rel.source not in adjacency_list:
+                adjacency_list[rel.source] = set()
+            adjacency_list[rel.source].add(rel.target)
+            unique_edges.add(frozenset({rel.source, rel.target}))
+            if rel.bidirectional:
+                if rel.target not in adjacency_list:
+                    adjacency_list[rel.target] = set()
+                adjacency_list[rel.target].add(rel.source)
+
+        # Aggregate clusters for each start node
+        start_node_clusters: dict[Node, set[frozenset[Node]]] = {}
+        # Sample enough starting nodes to handle the worst-case grouping scenario, where nodes are grouped
+        # in independent clusters of size equal to depth_limit. This only surfaces when there are fewer
+        # unique edges than nodes.
+        connected_nodes: set[Node] = set().union(*unique_edges)
+        sample_size: int = (
+            (n - 1) * depth_limit + 1
+            if len(unique_edges) < len(connected_nodes)
+            else max(n, depth_limit, 10)
+        )
+
+        def dfs(node: Node, start_node: Node, current_path: t.Set[Node]):
+            # Terminate exploration once enough usable clusters exist so complexity doesn't spiral
+            if len(start_node_clusters.get(start_node, [])) > sample_size:
+                return
+
+            current_path.add(node)
+            path_length = len(current_path)
+            at_max_depth = path_length >= depth_limit
+            neighbors = adjacency_list.get(node, None)
+
+            # If this is a leaf node or we've reached the depth limit,
+            # and we have a valid path of at least 2 nodes, add it as a cluster
+            if path_length > 1 and (
+                at_max_depth
+                or not neighbors
+                or all(n in current_path for n in neighbors)
+            ):
+                # Lazy initialization of the set for this start node
+                if start_node not in start_node_clusters:
+                    start_node_clusters[start_node] = set()
+                start_node_clusters[start_node].add(frozenset(current_path))
+            elif neighbors:
+                for neighbor in neighbors:
+                    # Block cycles
+                    if neighbor not in current_path:
+                        dfs(neighbor, start_node, current_path)
+
+            # Backtrack by removing the current node from the path
+            current_path.remove(node)
+
+        # Shuffle nodes for random starting points.
+        # Use the adjacency list, since it has filtered out isolated nodes.
+        # Sort by node ID for consistent ordering while maintaining algorithm effectiveness.
+        start_nodes = sorted(adjacency_list.keys(), key=lambda n: n.id.hex)
+        # Use a hash-based seed for reproducible but varied shuffling based on the nodes themselves
+        node_ids_str = "".join(n.id.hex for n in start_nodes)
+        node_hash = hashlib.sha256(node_ids_str.encode("utf-8")).hexdigest()
+        rng = random.Random(int(node_hash[:8], 16))  # Use the first 8 hex chars as the seed
+        rng.shuffle(start_nodes)
+        samples: list[Node] = start_nodes[:sample_size]
+        for start_node in samples:
+            dfs(start_node, start_node, set())
+
+        start_node_clusters_list: list[set[frozenset[Node]]] = list(
+            start_node_clusters.values()
+        )
+
+        # Iteratively pop from each start node's cluster set until we have n unique clusters.
+        # Avoid adding duplicates and subset/superset pairs so we have diversity; we
+        # favor supersets over subsets when given a choice.
+        unique_clusters: set[frozenset[Node]] = set()
+        i = 0
+        while len(unique_clusters) < n and start_node_clusters_list:
+            # Cycle through the start node clusters
+            current_index = i % len(start_node_clusters_list)
+
+            current_start_node_clusters: set[frozenset[Node]] = (
+                start_node_clusters_list[current_index]
+            )
+            cluster: frozenset[Node] = current_start_node_clusters.pop()
+
+            # Check whether the new cluster is a subset of any existing cluster,
+            # and collect any existing clusters that are subsets of this cluster
+            is_subset = False
+            subsets_to_remove: set[frozenset[Node]] = set()
+
+            for existing in unique_clusters:
+                if cluster.issubset(existing):
+                    is_subset = True
+                    break
+                elif cluster.issuperset(existing):
+                    subsets_to_remove.add(existing)
+
+            # Only add the new cluster if it's not a subset of any existing cluster
+            if not is_subset:
+                # Remove any subsets of the new cluster
+                unique_clusters -= subsets_to_remove
+                unique_clusters.add(cluster)
+
+            # If this set is now empty, remove it
+            if not current_start_node_clusters:
+                start_node_clusters_list.pop(current_index)
+                # Don't increment i, since we removed an element, to account for the shift
+            else:
+                i += 1
+
+        return [set(cluster) for cluster in unique_clusters]
+
     def remove_node(
         self, node: Node, inplace: bool = True
     ) -> t.Optional["KnowledgeGraph"]:
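To make the new API concrete, here is a minimal usage sketch. The node properties, the "next" relationship type, and the constructor calls are illustrative assumptions, not part of the commit:

from ragas.testset.graph import KnowledgeGraph, Node, NodeType, Relationship

# Hypothetical chain A -> B -> C -> D
nodes = [
    Node(type=NodeType.CHUNK, properties={"page_content": name})
    for name in "ABCD"
]
kg = KnowledgeGraph(nodes=nodes)
for src, dst in zip(nodes, nodes[1:]):
    kg.relationships.append(
        Relationship(source=src, target=dst, type="next", properties={})
    )

# Ask for up to 2 clusters drawn from paths of at most 3 nodes
clusters = kg.find_n_indirect_clusters(n=2, depth_limit=3)
print([{node.properties["page_content"] for node in c} for c in clusters])

For this sparse chain (3 unique edges across 4 connected nodes), the sampling formula takes its worst-case branch: sample_size = (2 - 1) * 3 + 1 = 4, capped here by the 3 available start nodes.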

src/ragas/testset/synthesizers/multi_hop/abstract.py

Lines changed: 20 additions & 9 deletions
@@ -8,7 +8,6 @@
 
 from ragas.prompt import PydanticPrompt
 from ragas.testset.graph import KnowledgeGraph, Node
-from ragas.testset.graph_queries import get_child_nodes
 from ragas.testset.persona import Persona
 from ragas.testset.synthesizers.multi_hop.base import (
     MultiHopQuerySynthesizer,
@@ -39,11 +38,17 @@ class MultiHopAbstractQuerySynthesizer(MultiHopQuerySynthesizer):
     concept_combination_prompt: PydanticPrompt = ConceptCombinationPrompt()
     theme_persona_matching_prompt: PydanticPrompt = ThemesPersonasMatchingPrompt()
 
-    def get_node_clusters(self, knowledge_graph: KnowledgeGraph) -> t.List[t.Set[Node]]:
-        """Identify clusters of nodes based on the specified relationship condition."""
-        node_clusters = knowledge_graph.find_indirect_clusters(
-            relationship_condition=lambda rel: bool(
-                rel.get_property(self.relation_property)
+    def get_node_clusters(
+        self,
+        knowledge_graph: KnowledgeGraph,
+        n: int = 1,
+    ) -> t.List[t.Set[Node]]:
+        """Find n indirect clusters of nodes based on a relationship condition."""
+
+        node_clusters = knowledge_graph.find_n_indirect_clusters(
+            n,
+            relationship_condition=lambda rel: (
+                True if rel.get_property(self.relation_property) else False
             ),
             depth_limit=3,
         )
@@ -61,7 +66,7 @@ async def _generate_scenarios(
         Generate a list of scenarios of type MultiHopScenario.
 
         Steps to generate scenarios:
-        1. Find indirect clusters of nodes based on relationship condition
+        1. Find n indirect clusters of nodes based on relationship condition
         2. Calculate the number of samples that should be created per cluster to get n samples in total
         3. For each cluster of nodes
             a. Find the child nodes of the cluster nodes
@@ -70,7 +75,7 @@ async def _generate_scenarios(
         4. Sample diverse combinations of scenarios to get n samples
         """
 
-        node_clusters = self.get_node_clusters(knowledge_graph)
+        node_clusters = self.get_node_clusters(knowledge_graph, n)
         scenarios = []
 
         if len(node_clusters) == 0:
@@ -79,12 +84,18 @@ async def _generate_scenarios(
         )
         num_sample_per_cluster = int(np.ceil(n / len(node_clusters)))
 
+        child_relationships = [
+            rel for rel in knowledge_graph.relationships if rel.type == "child"
+        ]
+
         for cluster in node_clusters:
             if len(scenarios) >= n:
                 break
             nodes = []
             for node in cluster:
-                child_nodes = get_child_nodes(node, knowledge_graph, level=1)
+                child_nodes = [
+                    rel.target for rel in child_relationships if rel.source == node
+                ]
                 if child_nodes:
                     nodes.extend(child_nodes)
                 else:
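The replacement of per-node get_child_nodes calls with one precomputed child_relationships list is a filter-once optimization: the full relationship list is scanned a single time instead of once per node per cluster. A standalone sketch of the same pattern, plus an optional extra step (not in the commit) that indexes children by source for O(1) lookups; the Rel class and sample data are hypothetical:

from collections import defaultdict
from dataclasses import dataclass

@dataclass(frozen=True)
class Rel:  # stand-in for ragas' Relationship, illustration only
    source: str
    target: str
    type: str

rels = [
    Rel("doc1", "chunk1", "child"),
    Rel("doc1", "chunk2", "child"),
    Rel("doc2", "chunk3", "child"),
]

# What the commit does: filter the "child" relationships once, outside the loop
child_relationships = [r for r in rels if r.type == "child"]

# Going one step further: index by source so each per-node lookup is O(1)
children_by_source = defaultdict(list)
for r in child_relationships:
    children_by_source[r.source].append(r.target)

print(children_by_source["doc1"])  # ['chunk1', 'chunk2']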

tests/unit/test_graph.py

Lines changed: 10 additions & 1 deletion
@@ -173,7 +173,16 @@ def simple_graph(self):
         ),
         (
             4,
-            [],  # depth_limit=4 > max(cluster_size), so no paths are identified
+            [
+                ("A", "C"),
+                ("E", "F", "G"),
+                ("B", "C"),
+                ("A", "B"),
+                ("F", "G"),
+                ("A", "B", "C"),
+                ("E", "F"),
+                ("E", "G"),
+            ],
         ),
     ],
 )
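Since the clustering methods return results in no guaranteed order, comparisons against expected tuples like these presumably normalize both sides first; a minimal standalone sketch of that idea (helper name hypothetical):

def normalize(clusters):
    # Order-insensitive form: a set of frozensets of node names
    return {frozenset(c) for c in clusters}

expected = [("A", "C"), ("E", "F", "G"), ("B", "C")]
actual = [{"C", "A"}, {"G", "F", "E"}, {"C", "B"}]
assert normalize(expected) == normalize(actual)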
