1313import requests
1414from agentlab .llm .chat_api import ChatModel
1515import re
16+ import json
1617from agentlab .llm .response_api import APIPayload
1718
1819logger = logging .getLogger (__name__ )
@@ -25,6 +26,7 @@ def __init__(
2526 hint_db_path : str ,
2627 hint_retrieval_mode : Literal ["direct" , "llm" , "emb" ] = "direct" ,
2728 skip_hints_for_current_task : bool = False ,
29+ skip_hints_for_current_goal : bool = False ,
2830 top_n : int = 4 ,
2931 embedder_model : str = "Qwen/Qwen3-Embedding-0.6B" ,
3032 embedder_server : str = "http://localhost:5000" ,
@@ -36,6 +38,7 @@ def __init__(
3638 self .hint_db_path = hint_db_path
3739 self .hint_retrieval_mode = hint_retrieval_mode
3840 self .skip_hints_for_current_task = skip_hints_for_current_task
41+ self .skip_hints_for_current_goal = skip_hints_for_current_goal
3942 self .top_n = top_n
4043 self .embedder_model = embedder_model
4144 self .embedder_server = embedder_server
@@ -45,7 +48,16 @@ def __init__(
4548 self .hint_db_path = Path (hint_db_path ).as_posix ()
4649 else :
4750 self .hint_db_path = (Path (__file__ ).parent / self .hint_db_path ).as_posix ()
48- self .hint_db = pd .read_csv (self .hint_db_path , header = 0 , index_col = None , dtype = str )
51+ self .hint_db = pd .read_csv (
52+ self .hint_db_path ,
53+ header = 0 ,
54+ index_col = None ,
55+ dtype = str ,
56+ converters = {
57+ "trace_paths_json" : lambda x : json .loads (x ) if pd .notna (x ) else [],
58+ "source_trace_goals" : lambda x : json .loads (x ) if pd .notna (x ) else [],
59+ },
60+ )
4961 logger .info (f"Loaded { len (self .hint_db )} hints from database { self .hint_db_path } " )
5062 if self .hint_retrieval_mode == "emb" :
5163 self .load_hint_vectors ()
@@ -84,7 +96,9 @@ def choose_hints_llm(self, llm, goal: str, task_name: str) -> list[str]:
8496 topic_to_hints = defaultdict (list )
8597 skip_hints = []
8698 if self .skip_hints_for_current_task :
87- skip_hints = self .get_current_task_hints (task_name )
99+ skip_hints += self .get_current_task_hints (task_name )
100+ if self .skip_hints_for_current_goal :
101+ skip_hints += self .get_current_goal_hints (goal )
88102 for _ , row in self .hint_db .iterrows ():
89103 hint = row ["hint" ]
90104 if hint in skip_hints :
@@ -128,7 +142,9 @@ def choose_hints_emb(self, goal: str, task_name: str) -> list[str]:
128142 all_hints = self .uniq_hints ["hint" ].tolist ()
129143 skip_hints = []
130144 if self .skip_hints_for_current_task :
131- skip_hints = self .get_current_task_hints (task_name )
145+ skip_hints += self .get_current_task_hints (task_name )
146+ if self .skip_hints_for_current_goal :
147+ skip_hints += self .get_current_goal_hints (goal )
132148 hint_embeddings = []
133149 id_to_hint = {}
134150 for hint , emb in zip (all_hints , self .hint_embeddings ):
@@ -199,3 +215,7 @@ def get_current_task_hints(self, task_name):
199215 self .hint_db ["task_name" ].apply (lambda x : fnmatch .fnmatch (x , task_name ))
200216 ]
201217 return hints_df ["hint" ].tolist ()
218+
219+ def get_current_goal_hints (self , goal_str : str ):
220+ mask = self .hint_db ["source_trace_goals" ].apply (lambda goals : goal_str in goals )
221+ return self .hint_db .loc [mask , "hint" ].tolist ()
0 commit comments