@@ -63,12 +63,23 @@ class Evolution:
6363
6464 @staticmethod
6565 def merge_nodes (nodes : CurrentNodes ) -> Node :
66- return Node (
66+ # TODO: while merging merge according to the order of documents
67+ # if any nodes from same document take account their page order
68+
69+ new_node = Node (
6770 doc_id = "merged" ,
6871 page_content = "\n " .join (n .page_content for n in nodes .nodes ),
6972 keyphrases = [phrase for n in nodes .nodes for phrase in n .keyphrases ],
7073 )
7174
75+ embed_dim = len (nodes .nodes [0 ].embedding ) if nodes .nodes [0 ].embedding else None
76+ if embed_dim :
77+ node_embeddings = np .array ([n .embedding for n in nodes .nodes ]).reshape (
78+ - 1 , embed_dim
79+ )
80+ new_node .embedding = np .average (node_embeddings , axis = 0 )
81+ return new_node
82+
7283 def init (self , is_async : bool = True , run_config : t .Optional [RunConfig ] = None ):
7384 self .is_async = is_async
7485 if run_config is None :
@@ -191,7 +202,10 @@ async def generate_datarow(
191202 root_node = current_nodes .root_node , nodes = current_nodes .nodes
192203 )
193204 else :
194- relevant_context = current_nodes
205+ selected_nodes = [current_nodes .nodes [i ] for i in relevant_context_indices ]
206+ relevant_context = CurrentNodes (
207+ root_node = selected_nodes [0 ], nodes = selected_nodes
208+ )
195209
196210 merged_nodes = self .merge_nodes (relevant_context )
197211 results = await self .generator_llm .generate (
@@ -207,7 +221,7 @@ async def generate_datarow(
207221
208222 return DataRow (
209223 question = question ,
210- contexts = [n .page_content for n in current_nodes .nodes ],
224+ contexts = [n .page_content for n in relevant_context .nodes ],
211225 ground_truth = "" if answer is None else answer ,
212226 evolution_type = evolution_type ,
213227 )
@@ -253,7 +267,7 @@ async def _aevolve(
253267 assert self .question_filter is not None , "question_filter cannot be None"
254268
255269 merged_node = self .merge_nodes (current_nodes )
256- passed = await self .node_filter .filter (current_nodes . root_node )
270+ passed = await self .node_filter .filter (merged_node )
257271 if not passed ["score" ]:
258272 nodes = self .docstore .get_random_nodes (k = 1 )
259273 new_current_nodes = CurrentNodes (root_node = nodes [0 ], nodes = nodes )
@@ -334,16 +348,19 @@ async def _acomplex_evolution(
334348 assert self .question_filter is not None , "question_filter cannot be None"
335349 assert self .se is not None , "simple evolution cannot be None"
336350
337- simple_question , _ , _ = await self .se ._aevolve (current_tries , current_nodes )
351+ simple_question , current_nodes , _ = await self .se ._aevolve (
352+ current_tries , current_nodes
353+ )
338354 logger .debug (
339355 "[%s] simple question generated: %s" ,
340356 self .__class__ .__name__ ,
341357 simple_question ,
342358 )
343359
360+ merged_node = self .merge_nodes (current_nodes )
344361 result = await self .generator_llm .generate (
345362 prompt = question_prompt .format (
346- question = simple_question , context = current_nodes . root_node .page_content
363+ question = simple_question , context = merged_node .page_content
347364 )
348365 )
349366 reasoning_question = result .generations [0 ][0 ].text .strip ()
@@ -409,42 +426,53 @@ async def _aevolve(
409426 assert self .question_filter is not None , "question_filter cannot be None"
410427 assert self .se is not None , "simple evolution cannot be None"
411428
412- simple_question , _ , _ = await self .se ._aevolve (current_tries , current_nodes )
429+ simple_question , current_nodes , _ = await self .se ._aevolve (
430+ current_tries , current_nodes
431+ )
413432 logger .debug (
414433 "[MultiContextEvolution] simple question generated: %s" , simple_question
415434 )
416-
417435 # find a similar node and generate a question based on both
418- similar_node = self .docstore .get_similar (current_nodes .root_node )
436+ merged_node = self .merge_nodes (current_nodes )
437+ similar_node = self .docstore .get_similar (merged_node , top_k = 1 )
419438 if similar_node == []:
420439 # retry
421- current_nodes = self .se ._get_more_adjacent_nodes (current_nodes )
440+ new_random_nodes = self .docstore .get_random_nodes (k = 1 )
441+ current_nodes = CurrentNodes (
442+ root_node = new_random_nodes [0 ], nodes = new_random_nodes
443+ )
422444 return await self .aretry_evolve (current_tries , current_nodes )
445+ else :
446+ assert isinstance (similar_node [0 ], Node ), "similar_node must be a Node"
447+ current_nodes = CurrentNodes (
448+ root_node = merged_node , nodes = [merged_node , similar_node [0 ]]
449+ )
423450
424451 prompt = self .multi_context_question_prompt .format (
425452 question = simple_question ,
426- context1 = current_nodes . root_node .page_content ,
427- context2 = similar_node ,
453+ context1 = merged_node .page_content ,
454+ context2 = similar_node [ 0 ]. page_content ,
428455 )
429456 results = await self .generator_llm .generate (prompt = prompt )
430457 question = results .generations [0 ][0 ].text .strip ()
431458 logger .debug (
432459 "[MultiContextEvolution] multicontext question generated: %s" , question
433460 )
434461
462+ if not await self .question_filter .filter (question ):
463+ # retry
464+ current_nodes = self .se ._get_more_adjacent_nodes (current_nodes )
465+ return await self .aretry_evolve (current_tries , current_nodes )
466+
435467 # compress the question
436468 compressed_question = await self ._transform_question (
437469 prompt = self .compress_question_prompt , question = question
438470 )
439471 logger .debug (
440- "[MultiContextEvolution] multicontext question compressed: %s" , question
472+ "[MultiContextEvolution] multicontext question compressed: %s" ,
473+ compressed_question ,
441474 )
442475
443- if not await self .question_filter .filter (compressed_question ):
444- # retry
445- current_nodes = self .se ._get_more_adjacent_nodes (current_nodes )
446- return await self .aretry_evolve (current_tries , current_nodes )
447-
448476 assert self .evolution_filter is not None , "evolution filter cannot be None"
449477 if await self .evolution_filter .filter (simple_question , compressed_question ):
450478 # retry
0 commit comments