Skip to content

Commit fcf42b3

Browse files
authored
Merge pull request #289 from ServiceNow/hints_retrieve
Hints retrieval in generic agent
2 parents 5405029 + e4cad16 commit fcf42b3

File tree

6 files changed

+189
-95
lines changed

6 files changed

+189
-95
lines changed

src/agentlab/agents/generic_agent_hinter/generic_agent.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,10 +111,11 @@ def get_action(self, obs):
111111
previous_plan=self.plan,
112112
step=self.plan_step,
113113
flags=self.flags,
114+
llm=self.chat_llm,
114115
)
115116

116117
# Set task name for task hints if available
117-
if self.flags.use_task_hint and hasattr(self, 'task_name'):
118+
if self.flags.use_task_hint and hasattr(self, "task_name"):
118119
main_prompt.set_task_name(self.task_name)
119120

120121
max_prompt_tokens, max_trunc_itr = self._get_maxes()

src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py

Lines changed: 50 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,16 @@
66

77
import logging
88
from dataclasses import dataclass
9+
from pathlib import Path
10+
from typing import Literal
911

10-
from browsergym.core import action
12+
import pandas as pd
1113
from browsergym.core.action.base import AbstractActionSet
1214

1315
from agentlab.agents import dynamic_prompting as dp
16+
from agentlab.agents.tool_use_agent.tool_use_agent import HintsSource
17+
from agentlab.llm.chat_api import ChatModel
1418
from agentlab.llm.llm_utils import HumanMessage, parse_html_tags_raise
15-
import fnmatch
16-
import pandas as pd
17-
from pathlib import Path
1819

1920

2021
@dataclass
@@ -49,6 +50,8 @@ class GenericPromptFlags(dp.Flags):
4950
use_abstract_example: bool = False
5051
use_hints: bool = False
5152
use_task_hint: bool = False
53+
task_hint_retrieval_mode: Literal["direct", "llm", "emb"] = "direct"
54+
skip_hints_for_current_task: bool = False
5255
hint_db_path: str = None
5356
enable_chat: bool = False
5457
max_prompt_tokens: int = None
@@ -70,10 +73,12 @@ def __init__(
7073
previous_plan: str,
7174
step: int,
7275
flags: GenericPromptFlags,
76+
llm: ChatModel,
7377
) -> None:
7478
super().__init__()
7579
self.flags = flags
7680
self.history = dp.History(obs_history, actions, memories, thoughts, flags.obs)
81+
goal = obs_history[-1]["goal_object"]
7782
if self.flags.enable_chat:
7883
self.instructions = dp.ChatInstructions(
7984
obs_history[-1]["chat_messages"], extra_instructions=flags.extra_instructions
@@ -84,7 +89,7 @@ def __init__(
8489
"Agent is in goal mode, but multiple user messages are present in the chat. Consider switching to `enable_chat=True`."
8590
)
8691
self.instructions = dp.GoalInstructions(
87-
obs_history[-1]["goal_object"], extra_instructions=flags.extra_instructions
92+
goal, extra_instructions=flags.extra_instructions
8893
)
8994

9095
self.obs = dp.Observation(
@@ -103,9 +108,14 @@ def time_for_caution():
103108
self.be_cautious = dp.BeCautious(visible=time_for_caution)
104109
self.think = dp.Think(visible=lambda: flags.use_thinking)
105110
self.hints = dp.Hints(visible=lambda: flags.use_hints)
111+
goal_str: str = goal[0]["text"]
106112
self.task_hint = TaskHint(
107113
use_task_hint=flags.use_task_hint,
108-
hint_db_path=flags.hint_db_path
114+
hint_db_path=flags.hint_db_path,
115+
goal=goal_str,
116+
hint_retrieval_mode=flags.task_hint_retrieval_mode,
117+
llm=llm,
118+
skip_hints_for_current_task=flags.skip_hints_for_current_task,
109119
)
110120
self.plan = Plan(previous_plan, step, lambda: flags.use_plan) # TODO add previous plan
111121
self.criticise = Criticise(visible=lambda: flags.use_criticise)
@@ -114,12 +124,12 @@ def time_for_caution():
114124
@property
115125
def _prompt(self) -> HumanMessage:
116126
prompt = HumanMessage(self.instructions.prompt)
117-
127+
118128
# Add task hints if enabled
119129
task_hints_text = ""
120-
if self.flags.use_task_hint and hasattr(self, 'task_name'):
130+
if self.flags.use_task_hint and hasattr(self, "task_name"):
121131
task_hints_text = self.task_hint.get_hints_for_task(self.task_name)
122-
132+
123133
prompt.add_text(
124134
f"""\
125135
{self.obs.prompt}\
@@ -286,11 +296,23 @@ def _parse_answer(self, text_answer):
286296

287297

288298
class TaskHint(dp.PromptElement):
289-
def __init__(self, use_task_hint: bool = True, hint_db_path: str = None) -> None:
299+
def __init__(
300+
self,
301+
use_task_hint: bool,
302+
hint_db_path: str,
303+
goal: str,
304+
hint_retrieval_mode: Literal["direct", "llm", "emb"],
305+
skip_hints_for_current_task: bool,
306+
llm: ChatModel,
307+
) -> None:
290308
super().__init__(visible=use_task_hint)
291309
self.use_task_hint = use_task_hint
292310
self.hint_db_rel_path = "hint_db.csv"
293311
self.hint_db_path = hint_db_path # Allow external path override
312+
self.hint_retrieval_mode: Literal["direct", "llm", "emb"] = hint_retrieval_mode
313+
self.skip_hints_for_current_task = skip_hints_for_current_task
314+
self.goal = goal
315+
self.llm = llm
294316
self._init()
295317

296318
_prompt = "" # Task hints are added dynamically in MainPrompt
@@ -316,42 +338,50 @@ def _init(self):
316338
hint_db_path = Path(self.hint_db_path)
317339
else:
318340
hint_db_path = Path(__file__).parent / self.hint_db_rel_path
319-
341+
320342
if hint_db_path.exists():
321343
self.hint_db = pd.read_csv(hint_db_path, header=0, index_col=None, dtype=str)
322344
# Verify the expected columns exist
323345
if "task_name" not in self.hint_db.columns or "hint" not in self.hint_db.columns:
324-
print(f"Warning: Hint database missing expected columns. Found: {list(self.hint_db.columns)}")
346+
print(
347+
f"Warning: Hint database missing expected columns. Found: {list(self.hint_db.columns)}"
348+
)
325349
self.hint_db = pd.DataFrame(columns=["task_name", "hint"])
326350
else:
327351
print(f"Warning: Hint database not found at {hint_db_path}")
328352
self.hint_db = pd.DataFrame(columns=["task_name", "hint"])
353+
self.hints_source = HintsSource(
354+
hint_db_path=hint_db_path.as_posix(),
355+
hint_retrieval_mode=self.hint_retrieval_mode,
356+
skip_hints_for_current_task=self.skip_hints_for_current_task,
357+
)
329358
except Exception as e:
330359
# Fallback to empty database on any error
331360
print(f"Warning: Could not load hint database: {e}")
332361
self.hint_db = pd.DataFrame(columns=["task_name", "hint"])
333362

334-
335363
def get_hints_for_task(self, task_name: str) -> str:
336364
"""Get hints for a specific task."""
337365
if not self.use_task_hint:
338366
return ""
339367

340368
# Ensure hint_db is initialized
341-
if not hasattr(self, 'hint_db'):
369+
if not hasattr(self, "hint_db"):
342370
self._init()
343371

344372
# Check if hint_db has the expected structure
345-
if self.hint_db.empty or "task_name" not in self.hint_db.columns or "hint" not in self.hint_db.columns:
373+
if (
374+
self.hint_db.empty
375+
or "task_name" not in self.hint_db.columns
376+
or "hint" not in self.hint_db.columns
377+
):
346378
return ""
347379

348380
try:
349-
task_hints = self.hint_db[
350-
self.hint_db["task_name"].apply(lambda x: fnmatch.fnmatch(x, task_name))
351-
]
381+
task_hints = self.hints_source.choose_hints(self.llm, task_name, self.goal)
352382

353383
hints = []
354-
for hint in task_hints["hint"]:
384+
for hint in task_hints:
355385
hint = hint.strip()
356386
if hint:
357387
hints.append(f"- {hint}")
@@ -364,5 +394,5 @@ def get_hints_for_task(self, task_name: str) -> str:
364394
return hints_str
365395
except Exception as e:
366396
print(f"Warning: Error getting hints for task {task_name}: {e}")
367-
397+
368398
return ""

0 commit comments

Comments
 (0)