33import logging
44import typing as t
55from abc import abstractmethod
6- from collections import namedtuple
76from dataclasses import dataclass , field
87
98from fsspec .exceptions import asyncio
9+ from langchain_core .pydantic_v1 import BaseModel
1010from numpy .random import default_rng
1111
12+ from ragas .llms import BaseRagasLLM
13+ from ragas .llms .prompt import Prompt
1214from ragas .testset .docstore import Direction , DocumentStore , Node
1315from ragas .testset .filters import EvolutionFilter , NodeFilter , QuestionFilter
1416from ragas .testset .prompts import (
2224rng = default_rng ()
2325logger = logging .getLogger (__name__ )
2426
25- if t .TYPE_CHECKING :
26- from ragas .llms import BaseRagasLLM
27- from ragas .llms .prompt import Prompt
28-
2927
@dataclass
class CurrentNodes:
    """Bundle of a root node plus a list of further nodes.

    Instances are passed into the evolution methods and handed back as the
    second element of an evolution's output tuple.
    """

    # Required: the root node of this bundle.
    # NOTE(review): presumably the node the evolution started from -- confirm.
    root_node: Node
    # Further nodes; each instance gets its own fresh empty list by default.
    nodes: t.List[Node] = field(default_factory=list)
3432
3533
# (question, current_nodes, evolution_type)
EvolutionOutput = t.Tuple[str, CurrentNodes, str]


class DataRow(BaseModel):
    """One generated row: a question, its context and answer, and the
    name of the evolution type that produced the question.

    Replaces the earlier ``namedtuple`` of the same name; the former
    ``question_type`` and ``evolution_elimination`` fields are gone and
    ``evolution_type`` is new.
    """

    question: str
    context: str
    answer: str
    evolution_type: str
4643
4744
4845@dataclass
4946class Evolution :
50- generator_llm : t . Optional [ BaseRagasLLM ] = None
47+ generator_llm : BaseRagasLLM = t . cast ( BaseRagasLLM , None )
5148 docstore : t .Optional [DocumentStore ] = None
5249 node_filter : t .Optional [NodeFilter ] = None
5350 question_filter : t .Optional [QuestionFilter ] = None
@@ -61,7 +58,7 @@ def merge_nodes(nodes: CurrentNodes) -> Node:
6158
6259 async def aretry_evolve (
6360 self , current_tries : int , current_nodes : CurrentNodes , update_count : bool = True
64- ) -> str :
61+ ) -> EvolutionOutput :
6562 if update_count :
6663 current_tries += 1
6764 logger .info ("retrying evolution: %s times" , current_tries )
@@ -112,22 +109,29 @@ def evolve(self, current_nodes: CurrentNodes) -> DataRow:
async def aevolve(self, current_nodes: CurrentNodes) -> DataRow:
    """Run this evolution asynchronously and turn its result into a row.

    Awaits ``self._aevolve`` starting with a try counter of 0, unpacks the
    returned ``(question, current_nodes, evolution_type)`` tuple, and passes
    those values to ``self.generate_datarow``.

    Args:
        current_nodes: nodes to start the evolution from.

    Returns:
        Whatever ``self.generate_datarow`` returns for the evolved question.
    """
    # init tries with 0 when first called
    current_tries = 0
    # The evolution may hand back a different node set than it was given,
    # so the row is built from the returned nodes, not the input ones.
    (
        evolved_question,
        current_nodes,
        evolution_type,
    ) = await self._aevolve(current_tries, current_nodes)

    return self.generate_datarow(
        question=evolved_question,
        current_nodes=current_nodes,
        evolution_type=evolution_type,
    )
120123
@abstractmethod
async def _aevolve(
    self, current_tries: int, current_nodes: CurrentNodes
) -> EvolutionOutput:
    """Produce one evolved question; implemented by each concrete evolution.

    Args:
        current_tries: try counter supplied by the caller (``aevolve``
            passes 0 on the first call).
        current_nodes: nodes to evolve the question from.

    Returns:
        A ``(question, current_nodes, evolution_type)`` tuple.
    """
    ...
124129
125130 def generate_datarow (
126131 self ,
127132 question : str ,
128133 current_nodes : CurrentNodes ,
129- question_type : str = "" ,
130- evolution_elimination : bool = False ,
134+ evolution_type : str ,
131135 ):
132136 assert self .generator_llm is not None , "generator_llm cannot be None"
133137
@@ -146,15 +150,16 @@ def generate_datarow(
146150 return DataRow (
147151 question = question ,
148152 context = merged_nodes .page_content ,
149- answer = answer ,
150- question_type = question_type ,
151- evolution_elimination = evolution_elimination ,
153+ answer = "" if answer is None else answer ,
154+ evolution_type = evolution_type ,
152155 )
153156
154157
155158@dataclass
156159class SimpleEvolution (Evolution ):
157- async def _aevolve (self , current_tries : int , current_nodes : CurrentNodes ) -> str :
160+ async def _aevolve (
161+ self , current_tries : int , current_nodes : CurrentNodes
162+ ) -> EvolutionOutput :
158163 assert self .docstore is not None , "docstore cannot be None"
159164 assert self .node_filter is not None , "node filter cannot be None"
160165 assert self .generator_llm is not None , "generator_llm cannot be None"
@@ -183,7 +188,7 @@ async def _aevolve(self, current_tries: int, current_nodes: CurrentNodes) -> str
183188 return await self .aretry_evolve (current_tries , current_nodes )
184189 else :
185190 # if valid question
186- return seed_question
191+ return seed_question , current_nodes , "simple"
187192
188193 def __hash__ (self ):
189194 return hash (self .__class__ .__name__ )
@@ -209,13 +214,15 @@ def init_evolution(self):
209214
210215@dataclass
211216class MultiContextEvolution (ComplexEvolution ):
212- async def _aevolve (self , current_tries : int , current_nodes : CurrentNodes ) -> str :
217+ async def _aevolve (
218+ self , current_tries : int , current_nodes : CurrentNodes
219+ ) -> EvolutionOutput :
213220 assert self .docstore is not None , "docstore cannot be None"
214221 assert self .generator_llm is not None , "generator_llm cannot be None"
215222 assert self .question_filter is not None , "question_filter cannot be None"
216223 assert self .se is not None , "simple evolution cannot be None"
217224
218- simple_question = await self .se ._aevolve (current_tries , current_nodes )
225+ simple_question , _ , _ = await self .se ._aevolve (current_tries , current_nodes )
219226 logger .debug (
220227 "[MultiContextEvolution] simple question generated: %s" , simple_question
221228 )
@@ -254,20 +261,22 @@ async def _aevolve(self, current_tries: int, current_nodes: CurrentNodes) -> str
254261 current_nodes = self .se ._get_more_adjacent_nodes (current_nodes )
255262 return await self .aretry_evolve (current_tries , current_nodes )
256263
257- return compressed_question
264+ return compressed_question , current_nodes , "multi_context"
258265
259266 def __hash__ (self ):
260267 return hash (self .__class__ .__name__ )
261268
262269
263270@dataclass
264271class ReasoningEvolution (ComplexEvolution ):
265- async def _aevolve (self , current_tries : int , current_nodes : CurrentNodes ) -> str :
272+ async def _aevolve (
273+ self , current_tries : int , current_nodes : CurrentNodes
274+ ) -> EvolutionOutput :
266275 assert self .generator_llm is not None , "generator_llm cannot be None"
267276 assert self .question_filter is not None , "question_filter cannot be None"
268277 assert self .se is not None , "simple evolution cannot be None"
269278
270- simple_question = await self .se ._aevolve (current_tries , current_nodes )
279+ simple_question , _ , _ = await self .se ._aevolve (current_tries , current_nodes )
271280 logger .debug (
272281 "[ReasoningEvolution] simple question generated: %s" , simple_question
273282 )
@@ -304,7 +313,7 @@ async def _aevolve(self, current_tries: int, current_nodes: CurrentNodes) -> str
304313 )
305314 return await self .aretry_evolve (current_tries , current_nodes )
306315
307- return reasoning_question
316+ return reasoning_question , current_nodes , "reasoning"
308317
309318 def __hash__ (self ):
310319 return hash (self .__class__ .__name__ )
0 commit comments