
Commit 44f09b8

feat: improvements in default test generation (#1661)
- [x] Make sure test generation runs with short docs, long docs, with a small number of docs, etc.
- [x] Tune default settings for the above
- [x] Relaxed filters for query creation
1 parent e4a88d6 commit 44f09b8

File tree

11 files changed: +243 -101 lines changed

src/ragas/testset/persona.py

Lines changed: 7 additions & 2 deletions
@@ -1,5 +1,4 @@
 import logging
-import random
 import typing as t
 
 import numpy as np
@@ -19,7 +18,7 @@ def default_filter(node: Node) -> bool:
         node.type.name == "DOCUMENT"
         and node.properties.get("summary_embedding") is not None
     ):
-        return random.random() < 0.25
+        return True
     else:
         return False
 
@@ -92,8 +91,14 @@ def generate_personas_from_kg(
     """
 
     nodes = [node for node in kg.nodes if filter_fn(node)]
+    if len(nodes) == 0:
+        raise ValueError(
+            "No nodes that satisfied the given filer. Try changing the filter."
+        )
+
     summaries = [node.properties.get("summary") for node in nodes]
     summaries = [summary for summary in summaries if isinstance(summary, str)]
+    num_personas = min(num_personas, len(summaries))
 
     embeddings = []
     for node in nodes:
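With the random 25% subsampling removed, default_filter now keeps every DOCUMENT node that has a summary embedding, and generate_personas_from_kg fails fast when the filter matches nothing and clamps num_personas to the number of usable summaries. A minimal sketch of a custom filter in the same shape, using only the Node API visible in the hunk above; the length threshold and the function name are illustrative, not part of this commit:

from ragas.testset.graph import Node

def long_summary_filter(node: Node) -> bool:
    # Same shape as default_filter: keep DOCUMENT nodes that carry a summary embedding.
    if (
        node.type.name == "DOCUMENT"
        and node.properties.get("summary_embedding") is not None
    ):
        # Hypothetical extra condition; default_filter now simply returns True here.
        summary = node.properties.get("summary") or ""
        return len(summary) > 100
    return False

Passed as filter_fn, an over-strict filter like this now raises the new ValueError instead of silently producing zero personas.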

src/ragas/testset/synthesizers/__init__.py

Lines changed: 18 additions & 5 deletions
@@ -1,6 +1,7 @@
 import typing as t
 
 from ragas.llms import BaseRagasLLM
+from ragas.testset.graph import KnowledgeGraph
 from ragas.testset.synthesizers.multi_hop import (
     MultiHopAbstractQuerySynthesizer,
     MultiHopSpecificQuerySynthesizer,
@@ -14,12 +15,24 @@
 QueryDistribution = t.List[t.Tuple[BaseSynthesizer, float]]
 
 
-def default_query_distribution(llm: BaseRagasLLM) -> QueryDistribution:
-    return [
-        (SingleHopSpecificQuerySynthesizer(llm=llm), 0.5),
-        (MultiHopAbstractQuerySynthesizer(llm=llm), 0.25),
-        (MultiHopSpecificQuerySynthesizer(llm=llm), 0.25),
+def default_query_distribution(
+    llm: BaseRagasLLM, kg: t.Optional[KnowledgeGraph] = None
+) -> QueryDistribution:
+    """ """
+    default_queries = [
+        SingleHopSpecificQuerySynthesizer(llm=llm),
+        MultiHopAbstractQuerySynthesizer(llm=llm),
+        MultiHopSpecificQuerySynthesizer(llm=llm),
     ]
+    if kg is not None:
+        available_queries = []
+        for query in default_queries:
+            if query.get_node_clusters(kg):
+                available_queries.append(query)
+    else:
+        available_queries = default_queries
+
+    return [(query, 1 / len(available_queries)) for query in available_queries]
 
 
 __all__ = [
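default_query_distribution now accepts an optional knowledge graph, drops any default synthesizer whose get_node_clusters finds nothing in that graph, and splits the weight evenly among whatever remains (1/3 each when all three apply, replacing the old 0.5/0.25/0.25 split). A hedged usage sketch; the LangChain model and the KnowledgeGraph.load path are assumptions, not shown in this diff:

from langchain_openai import ChatOpenAI  # assumption: any LangChain chat model wrapped below

from ragas.llms import LangchainLLMWrapper
from ragas.testset.graph import KnowledgeGraph
from ragas.testset.synthesizers import default_query_distribution

llm = LangchainLLMWrapper(ChatOpenAI(model="gpt-4o-mini"))

# Without a graph: all three default synthesizers, weighted equally.
print([(synth.name, weight) for synth, weight in default_query_distribution(llm)])

# With a graph: synthesizers whose get_node_clusters() comes back empty are dropped,
# and the remaining ones share the weight equally.
kg = KnowledgeGraph.load("knowledge_graph.json")  # assumption: a previously saved graph
print([(synth.name, weight) for synth, weight in default_query_distribution(llm, kg)])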

src/ragas/testset/synthesizers/generate.py

Lines changed: 6 additions & 5 deletions
@@ -10,10 +10,7 @@
 from ragas._analytics import TestsetGenerationEvent, track
 from ragas.callbacks import new_group
 from ragas.cost import TokenUsageParser
-from ragas.embeddings.base import (
-    BaseRagasEmbeddings,
-    LlamaIndexEmbeddingsWrapper,
-)
+from ragas.embeddings.base import BaseRagasEmbeddings, LlamaIndexEmbeddingsWrapper
 from ragas.executor import Executor
 from ragas.llms import BaseRagasLLM, LangchainLLMWrapper, LlamaIndexLLMWrapper
 from ragas.run_config import RunConfig
@@ -155,6 +152,7 @@ def generate_with_langchain_docs(
 
         if not transforms:
             transforms = default_transforms(
+                documents=list(documents),
                 llm=transforms_llm or self.llm,
                 embedding_model=transforms_embedding_model,
             )
@@ -224,6 +222,7 @@ def generate_with_llamaindex_docs(
                 transforms_embedding_model
             )
             transforms = default_transforms(
+                documents=[LCDocument(page_content=doc.text) for doc in documents],
                 llm=llm_for_transforms,
                 embedding_model=embedding_model_for_transforms,
             )
@@ -312,7 +311,9 @@ def generate(
         if run_config is not None:
             self.llm.set_run_config(run_config)
 
-        query_distribution = query_distribution or default_query_distribution(self.llm)
+        query_distribution = query_distribution or default_query_distribution(
+            self.llm, self.knowledge_graph
+        )
         callbacks = callbacks or []
 
         # dict to store any callbacks we define
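The two document-facing entry points now hand the documents themselves to default_transforms (so the transform pipeline can adapt to short, long, or very few documents), and generate() forwards the generator's knowledge graph to default_query_distribution. A rough end-to-end sketch; the TestsetGenerator name and the embedding_model/testset_size parameter names come from the library's public API rather than from these hunks, so treat them as assumptions:

from langchain_community.document_loaders import DirectoryLoader  # assumption: any loader of LangChain Documents
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

from ragas.embeddings import LangchainEmbeddingsWrapper
from ragas.llms import LangchainLLMWrapper
from ragas.testset import TestsetGenerator  # assumption: the generator class defined in generate.py

docs = DirectoryLoader("data/", glob="**/*.md").load()

generator = TestsetGenerator(
    llm=LangchainLLMWrapper(ChatOpenAI(model="gpt-4o-mini")),
    embedding_model=LangchainEmbeddingsWrapper(OpenAIEmbeddings()),
)

# default_transforms now sees the documents, and generate() passes the built
# knowledge graph to default_query_distribution unless a custom
# query_distribution is supplied explicitly.
testset = generator.generate_with_langchain_docs(docs, testset_size=10)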

src/ragas/testset/synthesizers/multi_hop/abstract.py

Lines changed: 14 additions & 9 deletions
@@ -7,7 +7,7 @@
 import numpy as np
 
 from ragas.prompt import PydanticPrompt
-from ragas.testset.graph import KnowledgeGraph
+from ragas.testset.graph import KnowledgeGraph, Node
 from ragas.testset.graph_queries import get_child_nodes
 from ragas.testset.persona import Persona, PersonaList
 from ragas.testset.synthesizers.multi_hop.base import (
@@ -42,6 +42,17 @@ class MultiHopAbstractQuerySynthesizer(MultiHopQuerySynthesizer):
     concept_combination_prompt: PydanticPrompt = ConceptCombinationPrompt()
     theme_persona_matching_prompt: PydanticPrompt = ThemesPersonasMatchingPrompt()
 
+    def get_node_clusters(self, knowledge_graph: KnowledgeGraph) -> t.List[t.Set[Node]]:
+
+        node_clusters = knowledge_graph.find_indirect_clusters(
+            relationship_condition=lambda rel: (
+                True if rel.get_property("summary_similarity") else False
+            ),
+            depth_limit=3,
+        )
+        logger.info("found %d clusters", len(node_clusters))
+        return node_clusters
+
     async def _generate_scenarios(
         self,
         n: int,
@@ -61,18 +72,12 @@ async def _generate_scenarios(
         4. Sample diverse combinations of scenarios to get n samples
         """
 
-        node_clusters = knowledge_graph.find_indirect_clusters(
-            relationship_condition=lambda rel: (
-                True if rel.get_property("summary_similarity") else False
-            ),
-            depth_limit=3,
-        )
-        logger.info("found %d clusters", len(node_clusters))
+        node_clusters = self.get_node_clusters(knowledge_graph)
         scenarios = []
 
         if len(node_clusters) == 0:
             raise ValueError(
-                "No clusters found in the knowledge graph. Use a different Synthesizer."
+                "No clusters found in the knowledge graph. Try changing the relationship condition."
             )
         num_sample_per_cluster = int(np.ceil(n / len(node_clusters)))
 
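The cluster search moved out of _generate_scenarios into get_node_clusters, which is what default_query_distribution calls to decide whether this synthesizer applies to a given graph. A small pre-flight check along those lines; the saved-graph path and the wrapped LLM are assumptions:

from langchain_openai import ChatOpenAI  # assumption: any LangChain chat model

from ragas.llms import LangchainLLMWrapper
from ragas.testset.graph import KnowledgeGraph
from ragas.testset.synthesizers.multi_hop import MultiHopAbstractQuerySynthesizer

kg = KnowledgeGraph.load("knowledge_graph.json")  # assumption: a previously saved graph
synth = MultiHopAbstractQuerySynthesizer(llm=LangchainLLMWrapper(ChatOpenAI(model="gpt-4o-mini")))

clusters = synth.get_node_clusters(kg)
if not clusters:
    # Same signal default_query_distribution uses to drop this synthesizer:
    # no indirect clusters joined by summary_similarity relationships within depth 3.
    print("multi-hop abstract queries are not applicable to this graph")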

src/ragas/testset/synthesizers/multi_hop/base.py

Lines changed: 1 addition & 1 deletion
@@ -73,7 +73,7 @@ def prepare_combinations(
         valid_nodes = []
         for node in nodes:
             node_themes = [
-                theme.lower() for theme in node.get_property(property_name)
+                theme.lower() for theme in node.properties.get(property_name, [])
             ]
             if node.get_property(property_name) and any(
                 concept.lower() in node_themes for concept in combination
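This one-line change guards against nodes that lack the property: node.get_property(...) can return None, which is not iterable, whereas properties.get(property_name, []) falls back to an empty list so the subsequent any(...) check simply skips the node. A tiny illustration with a plain dict standing in for node.properties:

properties = {"summary": "..."}  # a node with no "entities" property

# Old shape: get("entities") returns None, and iterating None raises TypeError.
# themes = [theme.lower() for theme in properties.get("entities")]

# New shape: default to an empty list, so the node is skipped instead of crashing.
themes = [theme.lower() for theme in properties.get("entities", [])]
assert themes == []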

src/ragas/testset/synthesizers/multi_hop/specific.py

Lines changed: 27 additions & 15 deletions
@@ -7,7 +7,7 @@
 import numpy as np
 
 from ragas.prompt import PydanticPrompt
-from ragas.testset.graph import KnowledgeGraph
+from ragas.testset.graph import KnowledgeGraph, Node
 from ragas.testset.persona import Persona, PersonaList
 from ragas.testset.synthesizers.multi_hop.base import (
     MultiHopQuerySynthesizer,
@@ -38,9 +38,26 @@ class MultiHopSpecificQuerySynthesizer(MultiHopQuerySynthesizer):
     """
 
     name: str = "multi_hop_specific_query_synthesizer"
+    relation_type: str = "entities_overlap"
+    property_name: str = "entities"
     theme_persona_matching_prompt: PydanticPrompt = ThemesPersonasMatchingPrompt()
     generate_query_reference_prompt: PydanticPrompt = QueryAnswerGenerationPrompt()
 
+    def get_node_clusters(self, knowledge_graph: KnowledgeGraph) -> t.List[t.Set[Node]]:
+
+        cluster_dict = knowledge_graph.find_direct_clusters(
+            relationship_condition=lambda rel: (
+                True if rel.type == self.relation_type else False
+            )
+        )
+        logger.info("found %d clusters", len(cluster_dict))
+        node_clusters = []
+        for key_node, list_of_nodes in cluster_dict.items():
+            for node in list_of_nodes:
+                node_clusters.append((key_node, node))
+
+        return node_clusters
+
     async def _generate_scenarios(
         self,
         n: int,
@@ -61,26 +78,21 @@ async def _generate_scenarios(
         4. Return the list of scenarios of length n
         """
 
-        cluster_dict = knowledge_graph.find_direct_clusters(
-            relationship_condition=lambda rel: (
-                True if rel.type == "entities_overlap" else False
+        node_clusters = self.get_node_clusters(knowledge_graph)
+
+        if len(node_clusters) == 0:
+            raise ValueError(
+                "No clusters found in the knowledge graph. Try changing the relationship condition."
             )
-        )
+
+        num_sample_per_cluster = int(np.ceil(n / len(node_clusters)))
 
         valid_relationships = [
             rel
             for rel in knowledge_graph.relationships
-            if rel.type == "entities_overlap"
+            if rel.type == self.relation_type
         ]
-
-        node_clusters = []
-        for key_node, list_of_nodes in cluster_dict.items():
-            for node in list_of_nodes:
-                node_clusters.append((key_node, node))
-
-        logger.info("found %d clusters", len(cluster_dict))
         scenarios = []
-        num_sample_per_cluster = int(np.ceil(n / len(node_clusters)))
 
         for cluster in node_clusters:
             if len(scenarios) < n:
@@ -106,7 +118,7 @@ async def _generate_scenarios(
                     overlapped_items,
                     PersonaList(personas=persona_list),
                     persona_concepts,
-                    property_name="entities",
+                    property_name=self.property_name,
                 )
                 base_scenarios = self.sample_diverse_combinations(
                     base_scenarios, num_sample_per_cluster
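Promoting relation_type and property_name to class attributes means a variant synthesizer can hop across a different relationship/property pair just by overriding them, and get_node_clusters doubles as the applicability check used by default_query_distribution. A hypothetical subclass as a sketch; the keyphrases_overlap relationship and keyphrases property only exist if your transforms produce them:

from ragas.testset.synthesizers.multi_hop import MultiHopSpecificQuerySynthesizer


class MultiHopKeyphraseQuerySynthesizer(MultiHopSpecificQuerySynthesizer):
    # Hypothetical variant: hop across keyphrase overlap instead of entity overlap.
    name: str = "multi_hop_keyphrase_query_synthesizer"
    relation_type: str = "keyphrases_overlap"  # assumption: created by a custom transform
    property_name: str = "keyphrases"          # assumption: property present on nodes

default_query_distribution would keep such a synthesizer only when get_node_clusters finds at least one pair of nodes linked by the chosen relation type.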

src/ragas/testset/synthesizers/single_hop/specific.py

Lines changed: 35 additions & 11 deletions
@@ -2,12 +2,13 @@
 
 import logging
 import typing as t
+from collections import defaultdict
 from dataclasses import dataclass
 
 import numpy as np
 
 from ragas.prompt import PydanticPrompt
-from ragas.testset.graph import KnowledgeGraph
+from ragas.testset.graph import KnowledgeGraph, Node
 from ragas.testset.persona import Persona, PersonaList
 from ragas.testset.synthesizers.base import BaseScenario
 from ragas.testset.synthesizers.prompts import (
@@ -40,6 +41,37 @@ class SingleHopScenario(BaseScenario):
 class SingleHopSpecificQuerySynthesizer(SingleHopQuerySynthesizer):
     name: str = "single_hop_specifc_query_synthesizer"
     theme_persona_matching_prompt: PydanticPrompt = ThemesPersonasMatchingPrompt()
+    property_name: str = "entities"
+
+    def get_node_clusters(self, knowledge_graph: KnowledgeGraph) -> t.List[Node]:
+
+        node_type_dict = defaultdict(int)
+        for node in knowledge_graph.nodes:
+            if (
+                node.type.name == "CHUNK"
+                and node.get_property(self.property_name) is not None
+            ):
+                node_type_dict["CHUNK"] += 1
+            elif (
+                node.type.name == "DOCUMENT"
+                and node.get_property(self.property_name) is not None
+            ):
+                node_type_dict["DOCUMENT"] += 1
+            else:
+                pass
+
+        node_filter = (
+            "CHUNK"
+            if node_type_dict["CHUNK"] > node_type_dict["DOCUMENT"]
+            else "DOCUMENT"
+        )
+
+        nodes = []
+        for node in knowledge_graph.nodes:
+            if node.type.name == node_filter:
+                nodes.append(node)
+
+        return nodes
 
     async def _generate_scenarios(
         self,
@@ -61,15 +93,7 @@ async def _generate_scenarios(
         4. Return the list of scenarios
         """
 
-        property_name = "entities"
-        nodes = []
-        for node in knowledge_graph.nodes:
-            if (
-                node.type.name == "CHUNK"
-                and node.get_property(property_name) is not None
-            ):
-                nodes.append(node)
-
+        nodes = self.get_node_clusters(knowledge_graph)
         if len(nodes) == 0:
             raise ValueError("No nodes found with the `entities` property.")
         samples_per_node = int(np.ceil(n / len(nodes)))
@@ -78,7 +102,7 @@
         for node in nodes:
             if len(scenarios) >= n:
                 break
-            themes = node.get_property(property_name)
+            themes = node.properties.get(self.property_name, [""])
             prompt_input = ThemesPersonasInput(themes=themes, personas=persona_list)
             persona_concepts = await self.theme_persona_matching_prompt.generate(
                 data=prompt_input, llm=self.llm, callbacks=callbacks
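get_node_clusters now counts the CHUNK and DOCUMENT nodes that carry the entities property and uses whichever kind is more common, so single-hop query generation also works for small corpora whose documents were never chunked. A standalone restatement of the selection rule, using plain tuples instead of the Node API:

from collections import defaultdict

# (node type, has the "entities" property?) for an unchunked two-document corpus.
nodes = [("DOCUMENT", True), ("DOCUMENT", True), ("CHUNK", False)]

counts = defaultdict(int)
for node_type, has_entities in nodes:
    if has_entities:
        counts[node_type] += 1

# More usable DOCUMENT nodes than CHUNK nodes, so DOCUMENT nodes are selected.
node_filter = "CHUNK" if counts["CHUNK"] > counts["DOCUMENT"] else "DOCUMENT"
assert node_filter == "DOCUMENT"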
