Skip to content

Commit 96c2952

Browse files
authored
feat: improvements in test gen (#1645)
1. Add a node filtering mechanism. 2. Add minimum and maximum token limits in headline splitting.
1 parent 0415a2d commit 96c2952

File tree

8 files changed

+318
-45
lines changed

8 files changed

+318
-45
lines changed

src/ragas/testset/graph.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import json
22
import typing as t
33
import uuid
4+
from copy import deepcopy
45
from dataclasses import dataclass, field
56
from enum import Enum
67
from pathlib import Path
@@ -268,6 +269,53 @@ def dfs(node: Node, cluster: t.Set[Node], depth: int, path: t.Tuple[Node, ...]):
268269

269270
return unique_clusters
270271

272+
def remove_node(
    self, node: Node, inplace: bool = True
) -> t.Optional["KnowledgeGraph"]:
    """
    Removes a node and its associated relationships from the knowledge graph.

    Parameters
    ----------
    node : Node
        The node to be removed from the knowledge graph.
    inplace : bool, optional
        If True, modifies the knowledge graph in place.
        If False, returns a modified copy with the node removed.

    Returns
    -------
    KnowledgeGraph or None
        Returns a modified copy of the knowledge graph if `inplace` is False.
        Returns None if `inplace` is True.

    Raises
    ------
    ValueError
        If the node is not present in the knowledge graph.
    """
    if node not in self.nodes:
        raise ValueError("Node is not present in the knowledge graph.")

    # Pick the target graph once, then run a single removal path instead of
    # duplicating the node/relationship pruning in both branches.
    graph = self if inplace else deepcopy(self)
    # NOTE(review): after deepcopy this relies on Node equality (not identity)
    # so the copied node still matches `node` — same assumption as before.
    graph.nodes.remove(node)
    graph.relationships = [
        rel
        for rel in graph.relationships
        if rel.source != node and rel.target != node
    ]
    # Stdlib convention: in-place mutation returns None.
    return None if inplace else graph
318+
271319
def find_direct_clusters(
272320
self, relationship_condition: t.Callable[[Relationship], bool] = lambda _: True
273321
) -> t.Dict[Node, t.List[t.Set[Node]]]:

src/ragas/testset/graph_queries.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,38 @@ def dfs(current_node: Node, current_level: int):
3636
dfs(node, 1)
3737

3838
return children
39+
40+
41+
def get_parent_nodes(node: Node, graph: KnowledgeGraph, level: int = 1) -> t.List[Node]:
    """
    Get the parent nodes of a given node up to a specified level.

    A parent is the source of a "child" relationship whose target is the
    node being inspected.

    Parameters
    ----------
    node : Node
        The node to get the parents of.
    graph : KnowledgeGraph
        The knowledge graph containing the node.
    level : int
        The maximum level to which parent nodes are searched.

    Returns
    -------
    List[Node]
        The list of parent nodes up to the specified level.
    """
    found: t.List[Node] = []

    # Depth-limited walk: follow "child" relationships backwards
    # (target -> source), starting at depth 1 for direct parents.
    def walk_up(current: Node, depth: int) -> None:
        if depth > level:
            return
        for relationship in graph.relationships:
            if relationship.type == "child" and relationship.target == current:
                found.append(relationship.source)
                walk_up(relationship.source, depth + 1)

    walk_up(node, 1)
    return found

src/ragas/testset/transforms/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .base import BaseGraphTransformation, Extractor, RelationshipBuilder, Splitter
1+
from .base import BaseGraphTransformation, Extractor, RelationshipBuilder, Splitter, NodeFilter
22
from .default import default_transforms
33
from .engine import Parallel, Transforms, apply_transforms, rollback_transforms
44
from .extractors import (
@@ -13,6 +13,7 @@
1313
SummaryCosineSimilarityBuilder,
1414
)
1515
from .splitters import HeadlineSplitter
16+
from .filters import CustomNodeFilter
1617

1718
__all__ = [
1819
# base
@@ -37,4 +38,6 @@
3738
"SummaryCosineSimilarityBuilder",
3839
# splitters
3940
"HeadlineSplitter",
41+
"CustomNodeFilter",
42+
"NodeFilter",
4043
]

src/ragas/testset/transforms/base.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,3 +322,55 @@ async def apply_build_relationships(
322322

323323
filtered_kg = self.filter(kg)
324324
return [apply_build_relationships(filtered_kg=filtered_kg, original_kg=kg)]
325+
326+
327+
@dataclass
class NodeFilter(BaseGraphTransformation):
    """
    Graph transformation that removes every node flagged by `custom_filter`.
    """

    async def transform(self, kg: KnowledgeGraph) -> KnowledgeGraph:
        """
        Return a copy of `kg` with all flagged nodes removed.

        Bug fix: the previous implementation returned immediately after
        removing the FIRST flagged node, silently keeping every other
        flagged node (inconsistent with `generate_execution_plan`, which
        removes all of them). Removals now accumulate across the loop.
        """
        filtered = self.filter(kg)

        result = kg
        for node in filtered.nodes:
            flag = await self.custom_filter(node, kg)
            if flag:
                kg_ = result.remove_node(node, inplace=False)
                if isinstance(kg_, KnowledgeGraph):
                    result = kg_
                else:
                    raise ValueError("Error in removing node")
        return result

    @abstractmethod
    async def custom_filter(self, node: Node, kg: KnowledgeGraph) -> bool:
        """
        Abstract method to filter a node based on a prompt.

        Parameters
        ----------
        node : Node
            The node to be filtered.
        kg : KnowledgeGraph
            The graph the node belongs to.

        Returns
        -------
        bool
            A boolean indicating whether the node should be filtered
            (True means the node is removed).
        """
        pass

    def generate_execution_plan(self, kg: KnowledgeGraph) -> t.List[t.Coroutine]:
        """
        Generates a list of coroutines to be executed; each coroutine
        removes its node from `kg` in place when `custom_filter` flags it.
        """

        async def apply_filter(node: Node):
            if await self.custom_filter(node, kg):
                kg.remove_node(node)

        filtered = self.filter(kg)
        return [apply_filter(node) for node in filtered.nodes]
372+
373+
374+
@dataclass
375+
class LLMBasedNodeFilter(NodeFilter, PromptMixin):
376+
llm: BaseRagasLLM = field(default_factory=llm_factory)

src/ragas/testset/transforms/default.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
SummaryExtractor,
1010
)
1111
from ragas.testset.transforms.extractors.llm_based import NERExtractor, ThemesExtractor
12+
from ragas.testset.transforms.filters import CustomNodeFilter
1213
from ragas.testset.transforms.relationship_builders import (
1314
CosineSimilarityBuilder,
1415
OverlapScoreBuilder,
@@ -82,11 +83,14 @@ def summary_filter(node):
8283
threshold=0.01, filter_nodes=lambda node: node.type == NodeType.CHUNK
8384
)
8485

86+
node_filter = CustomNodeFilter(llm=llm, filter_nodes=lambda node: node.type == NodeType.CHUNK)
87+
8588
transforms = [
8689
headline_extractor,
8790
splitter,
88-
Parallel(summary_extractor, theme_extractor, ner_extractor),
89-
summary_emb_extractor,
91+
summary_extractor,
92+
node_filter,
93+
Parallel(summary_emb_extractor, theme_extractor, ner_extractor),
9094
Parallel(cosine_sim_builder, ner_overlap_sim),
9195
]
9296

src/ragas/testset/transforms/extractors/llm_based.py

Lines changed: 57 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -70,35 +70,44 @@ class Headlines(BaseModel):
7070

7171

7272
class HeadlinesExtractorPrompt(PydanticPrompt[StringIO, Headlines]):
73-
instruction: str = "Extract only level 2 headings from the given text."
73+
instruction: str = "Extract only level 2 and level 3 headings from the given text."
7474

7575
input_model: t.Type[StringIO] = StringIO
7676
output_model: t.Type[Headlines] = Headlines
7777
examples: t.List[t.Tuple[StringIO, Headlines]] = [
7878
(
7979
StringIO(
8080
text="""\
81-
Introduction
82-
Overview of the topic...
81+
Introduction
82+
Overview of the topic...
8383
84-
Main Concepts
85-
Explanation of core ideas...
84+
Main Concepts
85+
Explanation of core ideas...
8686
87-
Detailed Analysis
88-
Techniques and methods for analysis...
87+
Detailed Analysis
88+
Techniques and methods for analysis...
8989
90-
Subsection: Specialized Techniques
91-
Further details on specialized techniques...
90+
Subsection: Specialized Techniques
91+
Further details on specialized techniques...
9292
93-
Future Directions
94-
Insights into upcoming trends...
93+
Future Directions
94+
Insights into upcoming trends...
9595
96-
Conclusion
97-
Final remarks and summary.
98-
""",
96+
Subsection: Next Steps in Research
97+
Discussion of new areas of study...
98+
99+
Conclusion
100+
Final remarks and summary.
101+
"""
99102
),
100103
Headlines(
101-
headlines=["Main Concepts", "Detailed Analysis", "Future Directions"]
104+
headlines=[
105+
"Main Concepts",
106+
"Detailed Analysis",
107+
"Subsection: Specialized Techniques",
108+
"Future Directions",
109+
"Subsection: Next Steps in Research",
110+
]
102111
),
103112
),
104113
]
@@ -108,15 +117,24 @@ class NEROutput(BaseModel):
108117
entities: t.List[str]
109118

110119

111-
class NERPrompt(PydanticPrompt[StringIO, NEROutput]):
112-
instruction: str = "Extract named entities from the given text."
113-
input_model: t.Type[StringIO] = StringIO
120+
class TextWithExtractionLimit(BaseModel):
    # Prompt input that pairs free text with a cap on how many items to extract.
    text: str
    # Upper bound on extracted items; the NER prompt instructs the LLM not to
    # exceed it (see NERPrompt's instruction).
    max_num: int = 10
123+
124+
125+
class NERPrompt(PydanticPrompt[TextWithExtractionLimit, NEROutput]):
126+
instruction: str = (
127+
"Extract the named entities from the given text, limiting the output to the top entities. "
128+
"Ensure the number of entities does not exceed the specified maximum."
129+
)
130+
input_model: t.Type[TextWithExtractionLimit] = TextWithExtractionLimit
114131
output_model: t.Type[NEROutput] = NEROutput
115-
examples: t.List[t.Tuple[StringIO, NEROutput]] = [
132+
examples: t.List[t.Tuple[TextWithExtractionLimit, NEROutput]] = [
116133
(
117-
StringIO(
134+
TextWithExtractionLimit(
118135
text="""Elon Musk, the CEO of Tesla and SpaceX, announced plans to expand operations to new locations in Europe and Asia.
119-
This expansion is expected to create thousands of jobs, particularly in cities like Berlin and Shanghai."""
136+
This expansion is expected to create thousands of jobs, particularly in cities like Berlin and Shanghai.""",
137+
max_num=10,
120138
),
121139
NEROutput(
122140
entities=[
@@ -246,12 +264,16 @@ class NERExtractor(LLMBasedExtractor):
246264

247265
property_name: str = "entities"
248266
prompt: NERPrompt = NERPrompt()
267+
max_num_entities: int = 10
249268

250269
async def extract(self, node: Node) -> t.Tuple[str, t.List[str]]:
    """
    Extract named entities from the node's page content.

    Returns the property name paired with the extracted entity list;
    nodes without "page_content" yield an empty list.
    """
    text = node.get_property("page_content")
    if text is None:
        return self.property_name, []
    prompt_input = TextWithExtractionLimit(text=text, max_num=self.max_num_entities)
    result = await self.prompt.generate(self.llm, data=prompt_input)
    return self.property_name, result.entities
256278

257279

@@ -305,14 +327,17 @@ class ThemesAndConcepts(BaseModel):
305327
output: t.List[str]
306328

307329

308-
class ThemesAndConceptsExtractorPrompt(PydanticPrompt[StringIO, ThemesAndConcepts]):
330+
class ThemesAndConceptsExtractorPrompt(
331+
PydanticPrompt[TextWithExtractionLimit, ThemesAndConcepts]
332+
):
309333
instruction: str = "Extract the main themes and concepts from the given text."
310-
input_model: t.Type[StringIO] = StringIO
334+
input_model: t.Type[TextWithExtractionLimit] = TextWithExtractionLimit
311335
output_model: t.Type[ThemesAndConcepts] = ThemesAndConcepts
312-
examples: t.List[t.Tuple[StringIO, ThemesAndConcepts]] = [
336+
examples: t.List[t.Tuple[TextWithExtractionLimit, ThemesAndConcepts]] = [
313337
(
314-
StringIO(
315-
text="Artificial intelligence is transforming industries by automating tasks requiring human intelligence. AI analyzes vast data quickly and accurately, driving innovations like self-driving cars and personalized recommendations."
338+
TextWithExtractionLimit(
339+
text="Artificial intelligence is transforming industries by automating tasks requiring human intelligence. AI analyzes vast data quickly and accurately, driving innovations like self-driving cars and personalized recommendations.",
340+
max_num=10,
316341
),
317342
ThemesAndConcepts(
318343
output=[
@@ -343,10 +368,14 @@ class ThemesExtractor(LLMBasedExtractor):
343368

344369
property_name: str = "themes"
345370
prompt: ThemesAndConceptsExtractorPrompt = ThemesAndConceptsExtractorPrompt()
371+
max_num_themes: int = 10
346372

347373
async def extract(self, node: Node) -> t.Tuple[str, t.List[str]]:
    """
    Extract the main themes and concepts from the node's page content.

    Returns the property name paired with the extracted theme list;
    nodes without "page_content" yield an empty list.
    """
    text = node.get_property("page_content")
    if text is None:
        return self.property_name, []
    prompt_input = TextWithExtractionLimit(text=text, max_num=self.max_num_themes)
    result = await self.prompt.generate(self.llm, data=prompt_input)
    return self.property_name, result.output

0 commit comments

Comments
 (0)