@@ -1,4 +1,6 @@
+import re
 import typing as t
+import warnings
 from collections import defaultdict, namedtuple
 from dataclasses import dataclass

@@ -33,10 +35,10 @@
 )

 DEFAULT_TEST_DISTRIBUTION = {
-    "simple": 0.5,
+    "simple": 0.4,
     "reasoning": 0.2,
     "multi_context": 0.2,
-    "conditional": 0.1,
+    "conditional": 0.2,
 }

 question_deep_map = {
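Both the old and the new split keep the question-type fractions summing to one; a quick standalone check (illustrative only, not part of the patched module):

    distribution = {"simple": 0.4, "reasoning": 0.2, "multi_context": 0.2, "conditional": 0.2}
    assert abs(sum(distribution.values()) - 1.0) < 1e-9  # fractions of the generated test set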
@@ -106,7 +108,7 @@ def __init__(
         critic_llm: BaseLLM | BaseChatModel,
         embeddings_model: Embeddings,
         testset_distribution: t.Optional[t.Dict[str, float]] = None,
-        chat_qa: float = 0.3,
+        chat_qa: float = 0.0,
         chunk_size: int = 1024,
         seed: int = 42,
     ) -> None:
@@ -135,7 +137,7 @@ def from_default(
         openai_generator_llm: str = "gpt-3.5-turbo-16k",
         openai_filter_llm: str = "gpt-4",
         chat_qa: float = 0.3,
-        chunk_size: int = 1024,
+        chunk_size: int = 512,
     ):
         generator_llm = ChatOpenAI(model=openai_generator_llm)
         critic_llm = ChatOpenAI(model=openai_filter_llm)
@@ -173,14 +175,12 @@ def _filter_context(self, context: str) -> bool:
         prompt = ChatPromptTemplate.from_messages([human_prompt])
         results = generate(prompts=[prompt], llm=self.critic_llm)
         output = results.generations[0][0].text.strip()
-        score = eval(output)
-        if not isinstance(score, float | int):
-            index = output.lower().find("score:")
-            if index != -1:
-                index += len("score:")
-                score = eval(output[index:])
-            else:
-                score = 0.0
+        pattern = r"^[\d.]+$"
+        if not re.match(pattern, output):
+            score = 0.0
+        else:
+            score = eval(output)
+
         return score >= self.threshold

     def _seed_question(self, context: str) -> str:
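A rough sketch of how the new numeric-only parsing behaves, written as a standalone helper; parse_score is hypothetical and float stands in for the eval call in the diff:

    import re

    def parse_score(output: str) -> float:
        # Only a bare number such as "7.5" passes the ^[\d.]+$ check; anything else falls back to 0.0.
        return float(output) if re.match(r"^[\d.]+$", output) else 0.0

    parse_score("7.5")         # 7.5
    parse_score("Score: 7.5")  # 0.0 -- prose-wrapped scores are now rejected instead of parsed out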
@@ -241,22 +241,30 @@ def _generate_context(self, question: str, text_chunk: str) -> t.List[str]:
             for qstn in question.split("\n")
         ]

-    def _remove_index(self, available_indices: list, node_idx: list) -> t.List:
+    def _remove_nodes(self, available_indices: list, node_idx: list) -> t.List:
         for idx in node_idx:
             available_indices.remove(idx)
         return available_indices

-    def _generate_doc_node_map(
+    def _generate_doc_nodes_map(
         self, documenet_nodes: t.List[BaseNode]
-    ) -> t.Dict[str, list]:
-        doc_nodeidx = defaultdict(list)
-        for idx, node in enumerate(documenet_nodes):
-            doc_nodeidx[node.id_].append(idx)
-
-        return doc_nodeidx
-
-    def _get_neighbour_node(self, idx: int, node_indices: list) -> t.List[int]:
-        return [idx - 1, idx] if idx == node_indices[-1] else [idx, idx + 1]
+    ) -> t.Dict[str, BaseNode]:
+        doc_nodes_map: t.Dict[str, t.List[BaseNode]] = defaultdict(list[BaseNode])
+        for node in documenet_nodes:
+            if node.ref_doc_id:
+                doc_nodes_map[node.ref_doc_id].append(node)
+
+        return doc_nodes_map  # type: ignore
+
+    def _get_neighbour_node(
+        self, node: BaseNode, related_nodes: list[BaseNode]
+    ) -> t.List[BaseNode]:
+        if len(related_nodes) < 2:
+            warnings.warn("No neighbors exists")
+            return [node]
+        idx = related_nodes.index(node)
+        ids = [idx - 1, idx] if idx == (len(related_nodes) - 1) else [idx, idx + 1]
+        return [related_nodes[idx] for idx in ids]

     def _embed_nodes(self, nodes: t.List[BaseNode]) -> t.Dict[str, t.List[float]]:
         embeddings = {}
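A toy trace of the neighbour selection above, with plain strings standing in for BaseNode objects (illustrative only):

    def neighbours(node, related_nodes):
        # mirrors _get_neighbour_node for lists with at least two entries
        idx = related_nodes.index(node)
        ids = [idx - 1, idx] if idx == len(related_nodes) - 1 else [idx, idx + 1]
        return [related_nodes[i] for i in ids]

    neighbours("n0", ["n0", "n1", "n2"])  # ['n0', 'n1'] -- current node plus the next one
    neighbours("n2", ["n0", "n1", "n2"])  # ['n1', 'n2'] -- the last node falls back to its predecessor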
@@ -275,38 +283,38 @@ def generate(self, documents: t.List[Document], test_size: int) -> TestDataset:
         document_nodes: t.List[BaseNode] = node_parser.get_nodes_from_documents(
             documents=documents
         )
-
         # maximum 1 seed question per node
         if test_size > len(document_nodes):
             raise ValueError(
                 """Maximum possible number of samples exceeded,
                 reduce test_size or add more documents"""
             )

-        available_indices = np.arange(0, len(document_nodes)).tolist()
-        doc_nodeidx = self._generate_doc_node_map(document_nodes)
+        available_nodes = document_nodes
+        doc_nodes_map = self._generate_doc_nodes_map(document_nodes)
+        count_neighbours = sum(len(val) > 1 for _, val in doc_nodes_map.items())
+        if count_neighbours < len(documents) // 2:
+            warnings.warn("Most documents are too short")
+
         count = 0
         samples = []

         pbar = tqdm(total=test_size)
-        while count < test_size and available_indices != []:
+        while count < test_size and available_nodes != []:
             evolve_type = self._get_evolve_type()
-            node_idx = self.rng.choice(available_indices, size=1)[0]
-            available_indices = self._remove_index(available_indices, [node_idx])
+            curr_node = self.rng.choice(available_nodes, size=1)[0]
+            available_nodes = self._remove_nodes(available_nodes, [curr_node])

-            neighbor_nodes = doc_nodeidx[
-                document_nodes[node_idx].node_id  # type: ignore
-            ]
+            neighbor_nodes = doc_nodes_map[curr_node.source_node.node_id]

             # Append multiple nodes randomly to remove chunking bias
             size = self.rng.integers(1, 3)
-            node_indices = (
-                self._get_neighbour_node(node_idx, neighbor_nodes)
+            nodes = (
+                self._get_neighbour_node(curr_node, neighbor_nodes)
                 if size > 1 and evolve_type != "multi_context"
-                else [node_idx]
+                else [curr_node]
             )

-            nodes = [document_nodes[node_idx] for node_idx in node_indices]
             text_chunk = " ".join([node.get_content() for node in nodes])
             score = self._filter_context(text_chunk)
             if not score:
@@ -316,14 +324,13 @@ def generate(self, documents: t.List[Document], test_size: int) -> TestDataset:
             if evolve_type == "multi_context":
                 # Find most similar chunk in same document
                 node_embedding = self._embed_nodes([nodes[-1]])
-                neighbor_nodes = self._remove_index(neighbor_nodes, node_indices)
-                neighbor_emb = self._embed_nodes(
-                    [document_nodes[idx][0] for idx in neighbor_nodes]
-                )
+                neighbor_nodes = self._remove_nodes(neighbor_nodes, nodes)
+                neighbor_emb = self._embed_nodes(neighbor_nodes)
+
                 _, indices = get_top_k_embeddings(
                     list(node_embedding.values())[0],
                     list(neighbor_emb.values()),
-                    similarity_cutoff=self.threshold,
+                    similarity_cutoff=self.threshold / 10,
                 )
                 if indices:
                     best_neighbor = neighbor_nodes[indices[0]]
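The smaller cutoff presumably rescales the 0-10 context-filter threshold into the [0, 1] range used by embedding similarities (an inference from the code, not stated in the diff): a threshold of 7.5, for example, would give similarity_cutoff = 7.5 / 10 = 0.75.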
@@ -332,7 +339,7 @@ def generate(self, documents: t.List[Document], test_size: int) -> TestDataset:
                         context1=text_chunk,
                         context2=best_neighbor.get_content(),
                     )
-                    text_chunk = "\n".join([text_chunk, best_neighbor.get_context()])
+                    text_chunk = "\n".join([text_chunk, best_neighbor.get_content()])
                 else:
                     continue
