
Commit d840b16

shahules786 and jltham authored
feat: improvements in test synthesization (#1621)
PR 2 of improvements in test generation --------- Co-authored-by: Jin Lin Tham <[email protected]>
1 parent 5f74eb5 commit d840b16

32 files changed: +1336 -1151 lines

docs/getstarted/rag_testset_generation.md
Lines changed: 3 additions & 3 deletions

@@ -141,9 +141,9 @@ query_distribution = default_query_distribution(generator_llm)
 ```
 ```
 [
-    (AbstractQuerySynthesizer(llm=generator_llm), 0.25),
-    (ComparativeAbstractQuerySynthesizer(llm=generator_llm), 0.25),
-    (SpecificQuerySynthesizer(llm=generator_llm), 0.5),
+    (SingleHopSpecificQuerySynthesizer(llm=llm), 0.5),
+    (MultiHopAbstractQuerySynthesizer(llm=llm), 0.25),
+    (MultiHopSpecificQuerySynthesizer(llm=llm), 0.25),
 ]
 ```

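As a quick orientation for readers of the getting-started page, here is a hedged sketch of inspecting the new default mix or swapping in a custom one. It assumes `generator_llm` is an already-wrapped `BaseRagasLLM`, as in the surrounding guide; the import paths mirror the diffs further down.

```python
from ragas.testset.synthesizers import default_query_distribution
from ragas.testset.synthesizers.multi_hop import (
    MultiHopAbstractQuerySynthesizer,
    MultiHopSpecificQuerySynthesizer,
)
from ragas.testset.synthesizers.single_hop.specific import (
    SingleHopSpecificQuerySynthesizer,
)

# Inspect the default mix: 50% single-hop specific, 25% multi-hop abstract,
# 25% multi-hop specific. `generator_llm` is assumed to exist already.
query_distribution = default_query_distribution(generator_llm)
for synthesizer, weight in query_distribution:
    print(type(synthesizer).__name__, weight)

# Or build a custom mix; the weights should sum to 1.0.
custom_distribution = [
    (SingleHopSpecificQuerySynthesizer(llm=generator_llm), 0.7),
    (MultiHopAbstractQuerySynthesizer(llm=generator_llm), 0.3),
]
```
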
docs/references/testset_schema.md
Lines changed: 2 additions & 2 deletions

@@ -15,12 +15,12 @@
         members:
             - BaseScenario
 
-::: ragas.testset.synthesizers.specific_query.SpecificQueryScenario
+::: ragas.testset.synthesizers.single_hop.specific.SingleHopSpecificQuerySynthesizer
     options:
         show_root_heading: True
         show_root_full_path: False
 
-::: ragas.testset.synthesizers.abstract_query.AbstractQueryScenario
+::: ragas.testset.synthesizers.multi_hop.specific.MultiHopSpecificQuerySynthesizer
     options:
         show_root_heading: True
         show_root_full_path: False

src/ragas/metrics/_string.py
Lines changed: 2 additions & 0 deletions

@@ -13,6 +13,7 @@ class DistanceMeasure(Enum):
     LEVENSHTEIN = "levenshtein"
     HAMMING = "hamming"
     JARO = "jaro"
+    JARO_WINKLER = "jaro_winkler"
 
 
 @dataclass
@@ -77,6 +78,7 @@ def __post_init__(self):
             DistanceMeasure.LEVENSHTEIN: distance.Levenshtein,
             DistanceMeasure.HAMMING: distance.Hamming,
             DistanceMeasure.JARO: distance.Jaro,
+            DistanceMeasure.JARO_WINKLER: distance.JaroWinkler,
         }
 
     def init(self, run_config: RunConfig):
src/ragas/testset/graph.py
Lines changed: 84 additions & 16 deletions

@@ -206,11 +206,15 @@ def __repr__(self) -> str:
     def __str__(self) -> str:
         return self.__repr__()
 
-    def find_clusters(
-        self, relationship_condition: t.Callable[[Relationship], bool] = lambda _: True
+    def find_indirect_clusters(
+        self,
+        relationship_condition: t.Callable[[Relationship], bool] = lambda _: True,
+        depth_limit: int = 3,
     ) -> t.List[t.Set[Node]]:
         """
-        Finds clusters of nodes in the knowledge graph based on a relationship condition.
+        Finds indirect clusters of nodes in the knowledge graph based on a relationship condition.
+        Here if A -> B -> C -> D, then A, B, C, and D form a cluster. If there's also a path A -> B -> C -> E,
+        it will form a separate cluster.
 
         Parameters
         ----------
@@ -223,31 +227,95 @@ def find_clusters(
             A list of sets, where each set contains nodes that form a cluster.
         """
         clusters = []
-        visited = set()
+        visited_paths = set()
 
         relationships = [
            rel for rel in self.relationships if relationship_condition(rel)
         ]
 
-        def dfs(node: Node, cluster: t.Set[Node]):
-            visited.add(node)
+        def dfs(node: Node, cluster: t.Set[Node], depth: int, path: t.Tuple[Node, ...]):
+            if depth >= depth_limit or path in visited_paths:
+                return
+            visited_paths.add(path)
             cluster.add(node)
+
             for rel in relationships:
-                if rel.source == node and rel.target not in visited:
-                    dfs(rel.target, cluster)
-                # if the relationship is bidirectional, we need to check the reverse
+                neighbor = None
+                if rel.source == node and rel.target not in cluster:
+                    neighbor = rel.target
                 elif (
                     rel.bidirectional
                     and rel.target == node
-                    and rel.source not in visited
+                    and rel.source not in cluster
                 ):
-                    dfs(rel.source, cluster)
+                    neighbor = rel.source
+
+                if neighbor is not None:
+                    dfs(neighbor, cluster.copy(), depth + 1, path + (neighbor,))
+
+            # Add completed path-based cluster
+            if len(cluster) > 1:
+                clusters.append(cluster)
 
         for node in self.nodes:
-            if node not in visited:
-                cluster = set()
-                dfs(node, cluster)
-                if len(cluster) > 1:
+            initial_cluster = set()
+            dfs(node, initial_cluster, 0, (node,))
+
+        # Remove duplicates by converting clusters to frozensets
+        unique_clusters = [
+            set(cluster) for cluster in set(frozenset(c) for c in clusters)
+        ]
+
+        return unique_clusters
+
+    def find_direct_clusters(
+        self, relationship_condition: t.Callable[[Relationship], bool] = lambda _: True
+    ) -> t.Dict[Node, t.List[t.Set[Node]]]:
+        """
+        Finds direct clusters of nodes in the knowledge graph based on a relationship condition.
+        Here if A->B, and A->C, then A, B, and C form a cluster.
+
+        Parameters
+        ----------
+        relationship_condition : Callable[[Relationship], bool], optional
+            A function that takes a Relationship and returns a boolean, by default lambda _: True
+
+        Returns
+        -------
+        List[Set[Node]]
+            A list of sets, where each set contains nodes that form a cluster.
+        """
+
+        clusters = []
+        relationships = [
+            rel for rel in self.relationships if relationship_condition(rel)
+        ]
+        for node in self.nodes:
+            cluster = set()
+            cluster.add(node)
+            for rel in relationships:
+                if rel.bidirectional:
+                    if rel.source == node:
+                        cluster.add(rel.target)
+                    elif rel.target == node:
+                        cluster.add(rel.source)
+                else:
+                    if rel.source == node:
+                        cluster.add(rel.target)
+
+            if len(cluster) > 1:
+                if cluster not in clusters:
                     clusters.append(cluster)
 
-        return clusters
+        # Remove subsets from clusters
+        unique_clusters = []
+        for cluster in clusters:
+            if not any(cluster < other for other in clusters):
+                unique_clusters.append(cluster)
+        clusters = unique_clusters
+
+        cluster_dict = {}
+        for cluster in clusters:
+            cluster_dict.update({cluster.pop(): cluster})
+
+        return cluster_dict
src/ragas/testset/graph_queries.py
Lines changed: 38 additions & 0 deletions

@@ -0,0 +1,38 @@
+import typing as t
+
+from ragas.testset.graph import KnowledgeGraph, Node
+
+
+def get_child_nodes(node: Node, graph: KnowledgeGraph, level: int = 1) -> t.List[Node]:
+    """
+    Get the child nodes of a given node up to a specified level.
+
+    Parameters
+    ----------
+    node : Node
+        The node to get the children of.
+    graph : KnowledgeGraph
+        The knowledge graph containing the node.
+    level : int
+        The maximum level to which child nodes are searched.
+
+    Returns
+    -------
+    List[Node]
+        The list of child nodes up to the specified level.
+    """
+    children = []
+
+    # Helper function to perform depth-limited search for child nodes
+    def dfs(current_node: Node, current_level: int):
+        if current_level > level:
+            return
+        for rel in graph.relationships:
+            if rel.source == current_node and rel.type == "child":
+                children.append(rel.target)
+                dfs(rel.target, current_level + 1)
+
+    # Start DFS from the initial node at level 0
+    dfs(node, 1)
+
+    return children
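
A small sketch of the new helper on a toy graph; it assumes the same `Node`/`Relationship`/`KnowledgeGraph` constructors as in the sketch above, and that parent-child edges carry the relationship type "child", which is the only type this helper follows.

```python
from ragas.testset.graph import KnowledgeGraph, Node, NodeType, Relationship
from ragas.testset.graph_queries import get_child_nodes

# A document with one chunk, which itself has a sub-chunk.
doc = Node(type=NodeType.DOCUMENT, properties={"summary": "a document"})
chunk = Node(type=NodeType.CHUNK, properties={"page_content": "a chunk"})
sub_chunk = Node(type=NodeType.CHUNK, properties={"page_content": "a sub-chunk"})

kg = KnowledgeGraph(
    nodes=[doc, chunk, sub_chunk],
    relationships=[
        Relationship(source=doc, target=chunk, type="child"),
        Relationship(source=chunk, target=sub_chunk, type="child"),
    ],
)

# level=1 returns only direct children; level=2 also follows child-of-child edges.
print(get_child_nodes(doc, kg, level=1))  # [chunk]
print(get_child_nodes(doc, kg, level=2))  # [chunk, sub_chunk]
```
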
src/ragas/testset/synthesizers/__init__.py
Lines changed: 11 additions & 20 deletions

@@ -1,38 +1,29 @@
 import typing as t
 
 from ragas.llms import BaseRagasLLM
-
-from .abstract_query import (
-    AbstractQuerySynthesizer,
-    ComparativeAbstractQuerySynthesizer,
+from ragas.testset.synthesizers.multi_hop import (
+    MultiHopAbstractQuerySynthesizer,
+    MultiHopSpecificQuerySynthesizer,
+)
+from ragas.testset.synthesizers.single_hop.specific import (
+    SingleHopSpecificQuerySynthesizer,
 )
+
 from .base import BaseSynthesizer
-from .base_query import QuerySynthesizer
-from .specific_query import SpecificQuerySynthesizer
 
 QueryDistribution = t.List[t.Tuple[BaseSynthesizer, float]]
 
 
 def default_query_distribution(llm: BaseRagasLLM) -> QueryDistribution:
-    """
-    Default query distribution for the test set.
-
-    By default, 25% of the queries are generated using `AbstractQuerySynthesizer`,
-    25% are generated using `ComparativeAbstractQuerySynthesizer`, and 50% are
-    generated using `SpecificQuerySynthesizer`.
-    """
+    """ """
     return [
-        (AbstractQuerySynthesizer(llm=llm), 0.25),
-        (ComparativeAbstractQuerySynthesizer(llm=llm), 0.25),
-        (SpecificQuerySynthesizer(llm=llm), 0.5),
+        (SingleHopSpecificQuerySynthesizer(llm=llm), 0.5),
+        (MultiHopAbstractQuerySynthesizer(llm=llm), 0.25),
+        (MultiHopSpecificQuerySynthesizer(llm=llm), 0.25),
     ]
 
 
 __all__ = [
     "BaseSynthesizer",
-    "QuerySynthesizer",
-    "AbstractQuerySynthesizer",
-    "ComparativeAbstractQuerySynthesizer",
-    "SpecificQuerySynthesizer",
     "default_query_distribution",
 ]
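
For context, the distribution built here is what the test-set generator ultimately consumes. A hedged hand-off sketch, assuming a `TestsetGenerator(llm=..., knowledge_graph=...)` constructor, a `generate(testset_size=..., query_distribution=...)` method, and a pre-built knowledge graph `kg` plus wrapped `generator_llm`; those signatures are assumptions, not part of this commit.

```python
from ragas.testset import TestsetGenerator
from ragas.testset.synthesizers import default_query_distribution

# `generator_llm` is a wrapped BaseRagasLLM and `kg` a populated KnowledgeGraph;
# both are assumed to exist already.
distribution = default_query_distribution(generator_llm)

generator = TestsetGenerator(llm=generator_llm, knowledge_graph=kg)
testset = generator.generate(testset_size=10, query_distribution=distribution)
print(testset.to_pandas().head())
```
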
