Skip to content

Commit fe0bcc4

Browse files
authored
Fix: evolution flows for test generation (#602)
fixes : #599
1 parent 9149d20 commit fe0bcc4

File tree

1 file changed

+46
-18
lines changed

1 file changed

+46
-18
lines changed

src/ragas/testset/evolutions.py

Lines changed: 46 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -63,12 +63,23 @@ class Evolution:
6363

6464
@staticmethod
6565
def merge_nodes(nodes: CurrentNodes) -> Node:
66-
return Node(
66+
# TODO: while merging merge according to the order of documents
67+
# if any nodes from same document take account their page order
68+
69+
new_node = Node(
6770
doc_id="merged",
6871
page_content="\n".join(n.page_content for n in nodes.nodes),
6972
keyphrases=[phrase for n in nodes.nodes for phrase in n.keyphrases],
7073
)
7174

75+
embed_dim = len(nodes.nodes[0].embedding) if nodes.nodes[0].embedding else None
76+
if embed_dim:
77+
node_embeddings = np.array([n.embedding for n in nodes.nodes]).reshape(
78+
-1, embed_dim
79+
)
80+
new_node.embedding = np.average(node_embeddings, axis=0)
81+
return new_node
82+
7283
def init(self, is_async: bool = True, run_config: t.Optional[RunConfig] = None):
7384
self.is_async = is_async
7485
if run_config is None:
@@ -191,7 +202,10 @@ async def generate_datarow(
191202
root_node=current_nodes.root_node, nodes=current_nodes.nodes
192203
)
193204
else:
194-
relevant_context = current_nodes
205+
selected_nodes = [current_nodes.nodes[i] for i in relevant_context_indices]
206+
relevant_context = CurrentNodes(
207+
root_node=selected_nodes[0], nodes=selected_nodes
208+
)
195209

196210
merged_nodes = self.merge_nodes(relevant_context)
197211
results = await self.generator_llm.generate(
@@ -207,7 +221,7 @@ async def generate_datarow(
207221

208222
return DataRow(
209223
question=question,
210-
contexts=[n.page_content for n in current_nodes.nodes],
224+
contexts=[n.page_content for n in relevant_context.nodes],
211225
ground_truth="" if answer is None else answer,
212226
evolution_type=evolution_type,
213227
)
@@ -253,7 +267,7 @@ async def _aevolve(
253267
assert self.question_filter is not None, "question_filter cannot be None"
254268

255269
merged_node = self.merge_nodes(current_nodes)
256-
passed = await self.node_filter.filter(current_nodes.root_node)
270+
passed = await self.node_filter.filter(merged_node)
257271
if not passed["score"]:
258272
nodes = self.docstore.get_random_nodes(k=1)
259273
new_current_nodes = CurrentNodes(root_node=nodes[0], nodes=nodes)
@@ -334,16 +348,19 @@ async def _acomplex_evolution(
334348
assert self.question_filter is not None, "question_filter cannot be None"
335349
assert self.se is not None, "simple evolution cannot be None"
336350

337-
simple_question, _, _ = await self.se._aevolve(current_tries, current_nodes)
351+
simple_question, current_nodes, _ = await self.se._aevolve(
352+
current_tries, current_nodes
353+
)
338354
logger.debug(
339355
"[%s] simple question generated: %s",
340356
self.__class__.__name__,
341357
simple_question,
342358
)
343359

360+
merged_node = self.merge_nodes(current_nodes)
344361
result = await self.generator_llm.generate(
345362
prompt=question_prompt.format(
346-
question=simple_question, context=current_nodes.root_node.page_content
363+
question=simple_question, context=merged_node.page_content
347364
)
348365
)
349366
reasoning_question = result.generations[0][0].text.strip()
@@ -409,42 +426,53 @@ async def _aevolve(
409426
assert self.question_filter is not None, "question_filter cannot be None"
410427
assert self.se is not None, "simple evolution cannot be None"
411428

412-
simple_question, _, _ = await self.se._aevolve(current_tries, current_nodes)
429+
simple_question, current_nodes, _ = await self.se._aevolve(
430+
current_tries, current_nodes
431+
)
413432
logger.debug(
414433
"[MultiContextEvolution] simple question generated: %s", simple_question
415434
)
416-
417435
# find a similar node and generate a question based on both
418-
similar_node = self.docstore.get_similar(current_nodes.root_node)
436+
merged_node = self.merge_nodes(current_nodes)
437+
similar_node = self.docstore.get_similar(merged_node, top_k=1)
419438
if similar_node == []:
420439
# retry
421-
current_nodes = self.se._get_more_adjacent_nodes(current_nodes)
440+
new_random_nodes = self.docstore.get_random_nodes(k=1)
441+
current_nodes = CurrentNodes(
442+
root_node=new_random_nodes[0], nodes=new_random_nodes
443+
)
422444
return await self.aretry_evolve(current_tries, current_nodes)
445+
else:
446+
assert isinstance(similar_node[0], Node), "similar_node must be a Node"
447+
current_nodes = CurrentNodes(
448+
root_node=merged_node, nodes=[merged_node, similar_node[0]]
449+
)
423450

424451
prompt = self.multi_context_question_prompt.format(
425452
question=simple_question,
426-
context1=current_nodes.root_node.page_content,
427-
context2=similar_node,
453+
context1=merged_node.page_content,
454+
context2=similar_node[0].page_content,
428455
)
429456
results = await self.generator_llm.generate(prompt=prompt)
430457
question = results.generations[0][0].text.strip()
431458
logger.debug(
432459
"[MultiContextEvolution] multicontext question generated: %s", question
433460
)
434461

462+
if not await self.question_filter.filter(question):
463+
# retry
464+
current_nodes = self.se._get_more_adjacent_nodes(current_nodes)
465+
return await self.aretry_evolve(current_tries, current_nodes)
466+
435467
# compress the question
436468
compressed_question = await self._transform_question(
437469
prompt=self.compress_question_prompt, question=question
438470
)
439471
logger.debug(
440-
"[MultiContextEvolution] multicontext question compressed: %s", question
472+
"[MultiContextEvolution] multicontext question compressed: %s",
473+
compressed_question,
441474
)
442475

443-
if not await self.question_filter.filter(compressed_question):
444-
# retry
445-
current_nodes = self.se._get_more_adjacent_nodes(current_nodes)
446-
return await self.aretry_evolve(current_tries, current_nodes)
447-
448476
assert self.evolution_filter is not None, "evolution filter cannot be None"
449477
if await self.evolution_filter.filter(simple_question, compressed_question):
450478
# retry

0 commit comments

Comments
 (0)