Skip to content

Commit b113f52

Browse files
authored
fix: handle tuple-formatted entities in SingleHopSpecificQuerySynthesizer (#2377)
## Issue Link / Problem Description
<!-- Link to related issue or describe the problem this PR solves -->
- Fixes #2368

## Changes Made
<!-- Describe what you changed and why -->
- Added helper method `_extract_themes_from_items` to handle various formats

## Testing
<!-- Describe how this should be tested -->
### How to Test
- [x] Automated tests added/updated
1 parent 34b4733 commit b113f52

File tree

2 files changed

+193
-1
lines changed

2 files changed

+193
-1
lines changed

src/ragas/testset/synthesizers/single_hop/specific.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,44 @@ class SingleHopSpecificQuerySynthesizer(SingleHopQuerySynthesizer):
4343
theme_persona_matching_prompt: PydanticPrompt = ThemesPersonasMatchingPrompt()
4444
property_name: str = "entities"
4545

46+
def _extract_themes_from_items(self, items: t.Any) -> t.List[str]:
    """
    Extract unique theme names from various formats.

    Handles multiple data formats that might appear during synthesis:

    - List[Tuple[str, str]]: Entity pairs (from overlap detection)
    - List[List[str]]: Entity pairs as lists
    - List[str]: Direct entity names
    - Dict[str, Any]: Keys as entity names

    Non-string elements (and non-string dict keys) are ignored. Any other
    input type, such as ``None`` or a bare string, yields an empty list.

    Parameters
    ----------
    items : t.Any
        The items to extract themes from.

    Returns
    -------
    t.List[str]
        List of unique theme strings, in the order they were first seen.
    """
    if isinstance(items, dict):
        # Dict keys are already unique and insertion-ordered; keep only the
        # string keys so the declared List[str] return type holds.
        return [key for key in items if isinstance(key, str)]

    if not isinstance(items, list):
        return []

    # A dict is used as an insertion-ordered set: building the result from a
    # plain set would make the theme order vary between runs.
    ordered_themes: t.Dict[str, None] = {}
    for item in items:
        if isinstance(item, str):
            ordered_themes[item] = None
        elif isinstance(item, (tuple, list)):
            # Extract strings from pairs/sequences, skipping anything else.
            for element in item:
                if isinstance(element, str):
                    ordered_themes[element] = None

    return list(ordered_themes)
83+
4684
def get_node_clusters(self, knowledge_graph: KnowledgeGraph) -> t.List[Node]:
4785
node_type_dict = defaultdict(int)
4886
for node in knowledge_graph.nodes:
@@ -101,7 +139,14 @@ async def _generate_scenarios(
101139
for node in nodes:
102140
if len(scenarios) >= n:
103141
break
104-
themes = node.properties.get(self.property_name, [""])
142+
raw_themes = node.properties.get(self.property_name, [])
143+
# Extract themes from potentially mixed data types (handles tuples, lists, strings)
144+
themes = self._extract_themes_from_items(raw_themes)
145+
146+
if not themes: # Skip if no themes extracted
147+
logger.debug("No themes extracted from node %s. Skipping.", node.id)
148+
continue
149+
105150
prompt_input = ThemesPersonasInput(themes=themes, personas=persona_list)
106151
persona_concepts = await self.theme_persona_matching_prompt.generate(
107152
data=prompt_input, llm=self.llm, callbacks=callbacks
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
import typing as t
2+
3+
import pytest
4+
5+
from ragas.prompt import PydanticPrompt
6+
from ragas.testset.graph import KnowledgeGraph, Node, NodeType
7+
from ragas.testset.persona import Persona
8+
from ragas.testset.synthesizers.prompts import PersonaThemesMapping, ThemesPersonasInput
9+
from ragas.testset.synthesizers.single_hop.specific import (
10+
SingleHopSpecificQuerySynthesizer,
11+
)
12+
13+
14+
class MockThemePersonaMatchingPrompt(PydanticPrompt):
    """Test double that assigns the full theme list to every persona.

    The ``llm`` and ``callbacks`` arguments are accepted but never used.
    """

    async def generate(self, data: ThemesPersonasInput, llm, callbacks=None):
        all_themes: t.List[str] = data.themes
        persona_to_themes = {}
        for persona in data.personas:
            # Every persona receives the same, complete list of themes.
            persona_to_themes[persona.name] = all_themes
        return PersonaThemesMapping(mapping=persona_to_themes)
21+
22+
23+
def test_extract_themes_from_items_with_strings(fake_llm):
    """Plain string entries come back as themes."""
    expected = {"Theme1", "Theme2", "Theme3"}
    synth = SingleHopSpecificQuerySynthesizer(llm=fake_llm)

    extracted = synth._extract_themes_from_items(["Theme1", "Theme2", "Theme3"])

    assert set(extracted) == expected
31+
32+
33+
def test_extract_themes_from_items_with_tuples(fake_llm):
    """Tuple pairs are flattened into unique themes (the bug fix)."""
    synth = SingleHopSpecificQuerySynthesizer(llm=fake_llm)

    # Entity pairs in this shape triggered the ValidationError in issue #2368.
    entity_pairs = [("Entity1", "Entity1"), ("Entity2", "Entity2")]
    extracted = synth._extract_themes_from_items(entity_pairs)

    assert set(extracted) == {"Entity1", "Entity2"}
42+
43+
44+
def test_extract_themes_from_items_with_mixed_formats(fake_llm):
    """A list mixing strings, tuples and lists is handled in one pass."""
    synth = SingleHopSpecificQuerySynthesizer(llm=fake_llm)

    mixed = [
        "Theme1",
        ("Entity2", "Entity2"),
        ["Entity3", "Entity3"],
    ]
    extracted = synth._extract_themes_from_items(mixed)

    assert set(extracted) == {"Theme1", "Entity2", "Entity3"}
52+
53+
54+
def test_extract_themes_from_items_with_dict(fake_llm):
    """For a dict input, the keys are used as the themes."""
    synth = SingleHopSpecificQuerySynthesizer(llm=fake_llm)

    keyed = {"Theme1": "value1", "Theme2": "value2"}
    extracted = synth._extract_themes_from_items(keyed)

    assert set(extracted) == {"Theme1", "Theme2"}
62+
63+
64+
def test_extract_themes_from_items_empty_input(fake_llm):
    """Empty or unsupported inputs all produce an empty theme list."""
    synth = SingleHopSpecificQuerySynthesizer(llm=fake_llm)

    # An empty list, None, and a bare string are each expected to yield [].
    for unusable in ([], None, "invalid"):
        assert synth._extract_themes_from_items(unusable) == []
71+
72+
73+
def test_extract_themes_from_items_with_nested_empty_tuples(fake_llm):
    """Test _extract_themes_from_items skips empty sequences and non-string elements."""
    synthesizer = SingleHopSpecificQuerySynthesizer(llm=fake_llm)

    # The empty tuple/list entries cover the "nested empty" case named in the
    # test; the integer elements cover non-string filtering.
    items = [
        (),
        [],
        ("Theme1", 123),
        (456, "Theme2"),
        ("Theme3", "Theme3"),
    ]
    themes = synthesizer._extract_themes_from_items(items)

    # Only string elements should be extracted
    assert set(themes) == {"Theme1", "Theme2", "Theme3"}
82+
83+
84+
@pytest.mark.asyncio
async def test_generate_scenarios_with_tuple_entities(fake_llm):
    """_generate_scenarios accepts nodes whose entities are stored as tuples.

    Regression test for issue #2368, where an ``entities`` property
    containing tuples caused a ValidationError.
    """
    # Node carrying tuple-formatted entities (the problematic case).
    chunk_node = Node(type=NodeType.CHUNK)
    chunk_node.add_property(
        "entities", [("Entity1", "Entity1"), ("Entity2", "Entity2")]
    )
    graph = KnowledgeGraph(nodes=[chunk_node])

    researcher = Persona(
        name="Researcher",
        role_description="A researcher interested in entities.",
    )

    synth = SingleHopSpecificQuerySynthesizer(llm=fake_llm)
    synth.theme_persona_matching_prompt = MockThemePersonaMatchingPrompt()

    # Must complete without raising ValidationError.
    generated = await synth._generate_scenarios(
        n=2,
        knowledge_graph=graph,
        persona_list=[researcher],
        callbacks=None,
    )

    # At least one scenario must have been produced.
    assert len(generated) > 0
117+
118+
119+
@pytest.mark.asyncio
async def test_generate_scenarios_with_string_entities(fake_llm):
    """_generate_scenarios keeps working for plain string entities."""
    # Node carrying string-formatted entities (the normal case).
    chunk_node = Node(type=NodeType.CHUNK)
    chunk_node.add_property("entities", ["Entity1", "Entity2", "Entity3"])
    graph = KnowledgeGraph(nodes=[chunk_node])

    researcher = Persona(
        name="Researcher",
        role_description="A researcher interested in entities.",
    )

    synth = SingleHopSpecificQuerySynthesizer(llm=fake_llm)
    synth.theme_persona_matching_prompt = MockThemePersonaMatchingPrompt()

    # Behaviour for string entities is expected to be unchanged.
    generated = await synth._generate_scenarios(
        n=2,
        knowledge_graph=graph,
        persona_list=[researcher],
        callbacks=None,
    )

    # At least one scenario must have been produced.
    assert len(generated) > 0

0 commit comments

Comments
 (0)