Skip to content

Commit ee2653a

Browse files
committed
stepwise hint retrieval
1 parent 94fa1ab commit ee2653a

File tree

2 files changed

+32
-15
lines changed

2 files changed

+32
-15
lines changed

src/agentlab/agents/generic_agent_hinter/generic_agent.py

Lines changed: 5 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,7 @@
2626
from .generic_agent_prompt import (
2727
GenericPromptFlags,
2828
MainPrompt,
29-
StepWiseRetrievalPrompt,
29+
StepWiseContextIdentificationPrompt,
3030
)
3131

3232

@@ -111,10 +111,8 @@ def get_action(self, obs):
111111

112112
queries, think_queries = self._get_queries()
113113

114-
# TODO
115-
# use those queries to retreive from the database. e.g.:
116-
# hints = self.hint_db.get_hints(queries)
117-
# then add those hints to the main prompt
114+
# use those queries to retrieve from the database and pass to prompt if step-level
115+
queries_for_hints = queries if getattr(self.flags, "hint_level", "episode") == "step" else None
118116

119117
main_prompt = MainPrompt(
120118
action_set=self.action_set,
@@ -126,6 +124,7 @@ def get_action(self, obs):
126124
step=self.plan_step,
127125
flags=self.flags,
128126
llm=self.chat_llm,
127+
queries=queries_for_hints,
129128
)
130129

131130
# Set task name for task hints if available
@@ -183,7 +182,7 @@ def get_action(self, obs):
183182
def _get_queries(self):
184183
"""Retrieve queries for hinting."""
185184
system_prompt = SystemMessage(dp.SystemPrompt().prompt)
186-
query_prompt = StepWiseRetrievalPrompt(
185+
query_prompt = StepWiseContextIdentificationPrompt(
187186
obs_history=self.obs_history,
188187
actions=self.actions,
189188
thoughts=self.thoughts,

src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ class GenericPromptFlags(dp.Flags):
6262
max_trunc_itr: int = 20
6363
flag_group: str = None
6464
n_retrieval_queries: int = 3
65+
hint_level: Literal["episode", "step"] = "episode"
6566

6667

6768
class MainPrompt(dp.Shrinkable):
@@ -76,6 +77,7 @@ def __init__(
7677
step: int,
7778
flags: GenericPromptFlags,
7879
llm: ChatModel,
80+
queries: list[str] | None = None,
7981
) -> None:
8082
super().__init__()
8183
self.flags = flags
@@ -118,6 +120,8 @@ def time_for_caution():
118120
hint_retrieval_mode=flags.task_hint_retrieval_mode,
119121
llm=llm,
120122
skip_hints_for_current_task=flags.skip_hints_for_current_task,
123+
hint_level=flags.hint_level,
124+
queries=queries,
121125
)
122126
self.plan = Plan(previous_plan, step, lambda: flags.use_plan) # TODO add previous plan
123127
self.criticise = Criticise(visible=lambda: flags.use_criticise)
@@ -306,6 +310,8 @@ def __init__(
306310
hint_retrieval_mode: Literal["direct", "llm", "emb"],
307311
skip_hints_for_current_task: bool,
308312
llm: ChatModel,
313+
hint_level: Literal["episode", "step"] = "episode",
314+
queries: list[str] | None = None,
309315
) -> None:
310316
super().__init__(visible=use_task_hint)
311317
self.use_task_hint = use_task_hint
@@ -315,6 +321,8 @@ def __init__(
315321
self.skip_hints_for_current_task = skip_hints_for_current_task
316322
self.goal = goal
317323
self.llm = llm
324+
self.hint_level: Literal["episode", "step"] = hint_level
325+
self.queries: list[str] | None = queries
318326
self._init()
319327

320328
_prompt = "" # Task hints are added dynamically in MainPrompt
@@ -352,6 +360,7 @@ def _init(self):
352360
else:
353361
print(f"Warning: Hint database not found at {hint_db_path}")
354362
self.hint_db = pd.DataFrame(columns=["task_name", "hint"])
363+
355364
self.hints_source = HintsSource(
356365
hint_db_path=hint_db_path.as_posix(),
357366
hint_retrieval_mode=self.hint_retrieval_mode,
@@ -380,7 +389,16 @@ def get_hints_for_task(self, task_name: str) -> str:
380389
return ""
381390

382391
try:
383-
task_hints = self.hints_source.choose_hints(self.llm, task_name, self.goal)
392+
# When step-level, pass queries as goal string to fit the llm_prompt
393+
goal_or_queries = self.goal
394+
if self.hint_level == "step" and self.queries:
395+
goal_or_queries = "\n".join(self.queries)
396+
397+
task_hints = self.hints_source.choose_hints(
398+
self.llm,
399+
task_name,
400+
goal_or_queries,
401+
)
384402

385403
hints = []
386404
for hint in task_hints:
@@ -400,14 +418,14 @@ def get_hints_for_task(self, task_name: str) -> str:
400418
return ""
401419

402420

403-
class StepWiseRetrievalPrompt(dp.Shrinkable):
421+
class StepWiseContextIdentificationPrompt(dp.Shrinkable):
404422
def __init__(
405423
self,
406424
obs_history: list[dict],
407425
actions: list[str],
408426
thoughts: list[str],
409427
obs_flags: dp.ObsFlags,
410-
n_queries: int = 3,
428+
n_queries: int = 1,
411429
) -> None:
412430
super().__init__()
413431
self.obs_flags = obs_flags
@@ -430,10 +448,10 @@ def _prompt(self) -> HumanMessage:
430448
)
431449

432450
example_queries = [
433-
"How to sort with multiple columns on the ServiceNow platform?",
434-
"What are the potential challenges of sorting by multiple columns?",
435-
"How to handle sorting by multiple columns in a table?",
436-
"Can I use the filter tool to sort by multiple columns?",
451+
"The user has started sorting a table and needs to apply multiple column criteria simultaneously.",
452+
"The user is attempting to configure advanced sorting options but the interface is unclear.",
453+
"The user has selected the first sort column and is now looking for how to add a second sort criterion.",
454+
"The user is in the middle of a multi-step sorting process and needs guidance on the next action.",
437455
]
438456

439457
example_queries_str = json.dumps(example_queries[: self.n_queries], indent=2)
@@ -442,8 +460,8 @@ def _prompt(self) -> HumanMessage:
442460
f"""
443461
# Querying memory
444462
445-
Before choosing an action, let's search our available documentation and memory on how to approach this step.
446-
This could provide valuable hints on how to properly solve this task. Return your answer as follow
463+
Before choosing an action, let's search our available documentation and memory for relevant context.
464+
Generate a brief, general summary of the current status to help identify useful hints. Return your answer as follow
447465
<think>chain of thought</think>
448466
<queries>json list of strings</queries> for the queries. Return exactly {self.n_queries}
449467
queries in the list.

0 commit comments

Comments
 (0)