@@ -411,58 +411,62 @@ def choose_hints(self, llm, task_name: str, goal: str) -> list[str]:
411411 def choose_hints_llm (self , llm , goal : str , task_name : str ) -> list [str ]:
412412 """Choose hints using LLM to filter the hints."""
413413 topic_to_hints = defaultdict (list )
414- hints_df = self . hint_db
414+ skip_hints = []
415415 if self .skip_hints_for_current_task :
416- current_task_hints = self .get_current_task_hints (task_name )
417- hints_df = hints_df [ ~ hints_df [ "hint" ]. isin ( current_task_hints )]
418- logger . info (
419- f"Filtered out current task hints, remaining hints: { hints_df . shape [ 0 ] } out of { self . hint_db . shape [ 0 ] } "
420- )
421- for i , row in hints_df . iterrows ():
422- topic_to_hints [ row [ "semantic_keys" ]]. append ( i )
416+ skip_hints = self .get_current_task_hints (task_name )
417+ for _ , row in self . hint_db . iterrows ():
418+ hint = row [ "hint" ]
419+ if hint in skip_hints :
420+ continue
421+ topic_to_hints [ row [ "semantic_keys" ]]. append ( hint )
422+ logger . info ( f"Collected { len ( topic_to_hints ) } hint topics" )
423423 hint_topics = list (topic_to_hints .keys ())
424424 topics = "\n " .join ([f"{ i } . { h } " for i , h in enumerate (hint_topics )])
425425 prompt = self .llm_prompt .format (goal = goal , topics = topics )
426+
426427 if isinstance (llm , ChatModel ):
427428 response : str = llm (messages = [dict (role = "user" , content = prompt )])["content" ]
428429 else :
429430 response : str = llm (APIPayload (messages = [llm .msg .user ().add_text (prompt )])).think
430431 try :
431- hint_topic_idx = json .loads (response )
432- if hint_topic_idx < 0 or hint_topic_idx >= len (hint_topics ):
432+ topic_number = json .loads (response )
433+ if topic_number < 0 or topic_number >= len (hint_topics ):
433434 logger .error (f"Wrong LLM hint id response: { response } , no hints" )
434435 return []
435- hint_topic = hint_topics [hint_topic_idx ]
436- hint_indices = topic_to_hints [hint_topic ]
437- df = hints_df .iloc [hint_indices ].copy ()
438- df = df .drop_duplicates (subset = ["hint" ], keep = "first" ) # leave only unique hints
439- hints = df ["hint" ].tolist ()
440- logger .info (f"LLM hint topic { hint_topic_idx } , chosen hints: { df ['hint' ].tolist ()} " )
441- except json .JSONDecodeError :
442- logger .error (f"Failed to parse LLM hint id response: { response } , no hints" )
436+ hint_topic = hint_topics [topic_number ]
437+ hints = list (set (topic_to_hints [hint_topic ]))
438+ logger .info (f"LLM hint topic { topic_number } :'{ hint_topic } ', chosen hints: { hints } " )
439+ except Exception as e :
440+ logger .exception (f"Failed to parse LLM hint id response: { response } :\n { e } " )
443441 hints = []
444442 return hints
445443
def choose_hints_emb(self, goal: str, task_name: str) -> list[str]:
    """Choose the ``self.top_n`` hints most similar to *goal* by embeddings.

    The goal is encoded via ``self._encode``; cosine-style similarities
    against the pre-computed ``self.hint_embeddings`` are obtained from
    ``self._similarity`` and the top-scoring hints are returned.

    Args:
        goal: Task goal text to embed.
        task_name: Current task name; its own hints are excluded when
            ``self.skip_hints_for_current_task`` is set.

    Returns:
        Up to ``self.top_n`` hints in ascending similarity order, or ``[]``
        if any step of the embedding pipeline fails.
    """
    try:
        goal_embeddings = self._encode([goal], prompt="task description")
        all_hints = self.uniq_hints["hint"].tolist()
        skip_hints = []
        if self.skip_hints_for_current_task:
            skip_hints = self.get_current_task_hints(task_name)
        # Rebuild the embedding list without the skipped hints, keeping a
        # position -> hint map so similarity indices can be resolved back.
        hint_embeddings = []
        id_to_hint = {}
        for hint, emb in zip(all_hints, self.hint_embeddings):
            if hint in skip_hints:
                continue
            hint_embeddings.append(emb.tolist())
            id_to_hint[len(hint_embeddings) - 1] = hint
        logger.info(f"Prepared hint embeddings for {len(hint_embeddings)} hints")
        similarities = self._similarity(goal_embeddings.tolist(), hint_embeddings)
        # argsort is ascending, so the slice takes the top_n best matches.
        top_indices = similarities.argsort()[0][-self.top_n :].tolist()
        logger.info(f"Top hint indices based on embedding similarity: {top_indices}")
        hints = [id_to_hint[idx] for idx in top_indices]
        logger.info(f"Embedding-based hints chosen: {hints}")
    except Exception as e:
        logger.exception(f"Failed to choose hints using embeddings: {e}")
        hints = []
    return hints
466470
467471 def _encode (self , texts : list [str ], prompt : str = "" , timeout : int = 10 , max_retries : int = 5 ):
468472 """Call the encode API endpoint with timeout and retries"""
@@ -483,7 +487,11 @@ def _encode(self, texts: list[str], prompt: str = "", timeout: int = 10, max_ret
483487 raise ValueError ("Failed to encode hints" )
484488
485489 def _similarity (
486- self , texts1 : list [str ], texts2 : list [str ], timeout : int = 2 , max_retries : int = 5
490+ self ,
491+ texts1 : list ,
492+ texts2 : list ,
493+ timeout : int = 2 ,
494+ max_retries : int = 5 ,
487495 ):
488496 """Call the similarity API endpoint with timeout and retries"""
489497 for attempt in range (max_retries ):
0 commit comments