Skip to content

Commit f06c6d0

Browse files
add new flag to skip hints with the current goal in the hint source t… (#310)
* add new flag to skip hints with the current goal in the hint source traces
1 parent 87e2510 commit f06c6d0

File tree

1 file changed

+23
-3
lines changed

1 file changed

+23
-3
lines changed

src/agentlab/utils/hinting.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import requests
1414
from agentlab.llm.chat_api import ChatModel
1515
import re
16+
import json
1617
from agentlab.llm.response_api import APIPayload
1718

1819
logger = logging.getLogger(__name__)
@@ -25,6 +26,7 @@ def __init__(
2526
hint_db_path: str,
2627
hint_retrieval_mode: Literal["direct", "llm", "emb"] = "direct",
2728
skip_hints_for_current_task: bool = False,
29+
skip_hints_for_current_goal: bool = False,
2830
top_n: int = 4,
2931
embedder_model: str = "Qwen/Qwen3-Embedding-0.6B",
3032
embedder_server: str = "http://localhost:5000",
@@ -36,6 +38,7 @@ def __init__(
3638
self.hint_db_path = hint_db_path
3739
self.hint_retrieval_mode = hint_retrieval_mode
3840
self.skip_hints_for_current_task = skip_hints_for_current_task
41+
self.skip_hints_for_current_goal = skip_hints_for_current_goal
3942
self.top_n = top_n
4043
self.embedder_model = embedder_model
4144
self.embedder_server = embedder_server
@@ -45,7 +48,16 @@ def __init__(
4548
self.hint_db_path = Path(hint_db_path).as_posix()
4649
else:
4750
self.hint_db_path = (Path(__file__).parent / self.hint_db_path).as_posix()
48-
self.hint_db = pd.read_csv(self.hint_db_path, header=0, index_col=None, dtype=str)
51+
self.hint_db = pd.read_csv(
52+
self.hint_db_path,
53+
header=0,
54+
index_col=None,
55+
dtype=str,
56+
converters={
57+
"trace_paths_json": lambda x: json.loads(x) if pd.notna(x) else [],
58+
"source_trace_goals": lambda x: json.loads(x) if pd.notna(x) else [],
59+
},
60+
)
4961
logger.info(f"Loaded {len(self.hint_db)} hints from database {self.hint_db_path}")
5062
if self.hint_retrieval_mode == "emb":
5163
self.load_hint_vectors()
@@ -84,7 +96,9 @@ def choose_hints_llm(self, llm, goal: str, task_name: str) -> list[str]:
8496
topic_to_hints = defaultdict(list)
8597
skip_hints = []
8698
if self.skip_hints_for_current_task:
87-
skip_hints = self.get_current_task_hints(task_name)
99+
skip_hints += self.get_current_task_hints(task_name)
100+
if self.skip_hints_for_current_goal:
101+
skip_hints += self.get_current_goal_hints(goal)
88102
for _, row in self.hint_db.iterrows():
89103
hint = row["hint"]
90104
if hint in skip_hints:
@@ -128,7 +142,9 @@ def choose_hints_emb(self, goal: str, task_name: str) -> list[str]:
128142
all_hints = self.uniq_hints["hint"].tolist()
129143
skip_hints = []
130144
if self.skip_hints_for_current_task:
131-
skip_hints = self.get_current_task_hints(task_name)
145+
skip_hints += self.get_current_task_hints(task_name)
146+
if self.skip_hints_for_current_goal:
147+
skip_hints += self.get_current_goal_hints(goal)
132148
hint_embeddings = []
133149
id_to_hint = {}
134150
for hint, emb in zip(all_hints, self.hint_embeddings):
@@ -199,3 +215,7 @@ def get_current_task_hints(self, task_name):
199215
self.hint_db["task_name"].apply(lambda x: fnmatch.fnmatch(x, task_name))
200216
]
201217
return hints_df["hint"].tolist()
218+
219+
def get_current_goal_hints(self, goal_str: str):
220+
mask = self.hint_db["source_trace_goals"].apply(lambda goals: goal_str in goals)
221+
return self.hint_db.loc[mask, "hint"].tolist()

0 commit comments

Comments
 (0)