Skip to content

Commit 366cb9f

Browse files
authored
fix: filter and context selection (#672)
1 parent 6aa4b5b commit 366cb9f

File tree

1 file changed

+14
-10
lines changed

1 file changed

+14
-10
lines changed

src/ragas/testset/evolutions.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ async def generate_datarow(
182182
assert self.generator_llm is not None, "generator_llm cannot be None"
183183

184184
node_content = [
185-
f"{i}\t{n.page_content}" for i, n in enumerate(current_nodes.nodes)
185+
f"{i+1}\t{n.page_content}" for i, n in enumerate(current_nodes.nodes)
186186
]
187187
results = await self.generator_llm.generate(
188188
prompt=self.find_relevent_context_prompt.format(
@@ -197,15 +197,20 @@ async def generate_datarow(
197197
if isinstance(relevent_contexts_result, dict)
198198
else None
199199
)
200-
201200
if relevant_context_indices is None:
202201
relevant_context = CurrentNodes(
203202
root_node=current_nodes.root_node, nodes=current_nodes.nodes
204203
)
205204
else:
206-
selected_nodes = [current_nodes.nodes[i] for i in relevant_context_indices]
207-
relevant_context = CurrentNodes(
208-
root_node=selected_nodes[0], nodes=selected_nodes
205+
selected_nodes = [
206+
current_nodes.nodes[i - 1]
207+
for i in relevant_context_indices
208+
if i - 1 < len(current_nodes.nodes)
209+
]
210+
relevant_context = (
211+
CurrentNodes(root_node=selected_nodes[0], nodes=selected_nodes)
212+
if selected_nodes
213+
else current_nodes
209214
)
210215

211216
merged_nodes = self.merge_nodes(relevant_context)
@@ -278,10 +283,9 @@ async def _aevolve(
278283
merged_node = self.merge_nodes(current_nodes)
279284
passed = await self.node_filter.filter(merged_node)
280285
if not passed["score"]:
281-
nodes = self.docstore.get_random_nodes(k=1)
282-
new_current_nodes = CurrentNodes(root_node=nodes[0], nodes=nodes)
286+
current_nodes = self._get_new_random_node()
283287
return await self.aretry_evolve(
284-
current_tries, new_current_nodes, update_count=False
288+
current_tries, current_nodes, update_count=False
285289
)
286290

287291
logger.debug("keyphrases in merged node: %s", merged_node.keyphrases)
@@ -400,7 +404,7 @@ async def _acomplex_evolution(
400404
)
401405

402406
assert self.evolution_filter is not None, "evolution filter cannot be None"
403-
if not await self.evolution_filter.filter(simple_question, compressed_question):
407+
if await self.evolution_filter.filter(simple_question, compressed_question):
404408
# retry
405409
current_nodes = self.se._get_new_random_node()
406410
logger.debug(
@@ -500,7 +504,7 @@ async def _aevolve(
500504
)
501505

502506
assert self.evolution_filter is not None, "evolution filter cannot be None"
503-
if not await self.evolution_filter.filter(simple_question, compressed_question):
507+
if await self.evolution_filter.filter(simple_question, compressed_question):
504508
# retry
505509
current_nodes = self.se._get_new_random_node()
506510
return await self.aretry_evolve(current_tries, current_nodes)

0 commit comments

Comments
 (0)