Skip to content

Commit dabddcf

Browse files
authored
Merge pull request #291 from ServiceNow/step-wise-retrieval
Add StepWiseQueriesPrompt for enhanced query handling in GenericAgent
2 parents 69048c4 + ca11170 commit dabddcf

File tree

3 files changed

+136
-4
lines changed

3 files changed

+136
-4
lines changed

src/agentlab/agents/generic_agent/generic_agent.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ def obs_preprocessor(self, obs: dict) -> dict:
9898
def get_action(self, obs):
9999

100100
self.obs_history.append(obs)
101+
101102
main_prompt = MainPrompt(
102103
action_set=self.action_set,
103104
obs_history=self.obs_history,

src/agentlab/agents/generic_agent_hinter/generic_agent.py

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,11 @@
2323
from agentlab.llm.llm_utils import Discussion, ParseError, SystemMessage, retry
2424
from agentlab.llm.tracking import cost_tracker_decorator
2525

26-
from .generic_agent_prompt import GenericPromptFlags, MainPrompt
26+
from .generic_agent_prompt import (
27+
GenericPromptFlags,
28+
MainPrompt,
29+
StepWiseContextIdentificationPrompt,
30+
)
2731

2832

2933
@dataclass
@@ -102,6 +106,14 @@ def set_task_name(self, task_name: str):
102106
def get_action(self, obs):
103107

104108
self.obs_history.append(obs)
109+
110+
system_prompt = SystemMessage(dp.SystemPrompt().prompt)
111+
112+
queries, think_queries = self._get_queries()
113+
114+
# use those queries to retrieve from the database and pass to prompt if step-level
115+
queries_for_hints = queries if getattr(self.flags, "hint_level", "episode") == "step" else None
116+
105117
main_prompt = MainPrompt(
106118
action_set=self.action_set,
107119
obs_history=self.obs_history,
@@ -112,6 +124,7 @@ def get_action(self, obs):
112124
step=self.plan_step,
113125
flags=self.flags,
114126
llm=self.chat_llm,
127+
queries=queries_for_hints,
115128
)
116129

117130
# Set task name for task hints if available
@@ -120,8 +133,6 @@ def get_action(self, obs):
120133

121134
max_prompt_tokens, max_trunc_itr = self._get_maxes()
122135

123-
system_prompt = SystemMessage(dp.SystemPrompt().prompt)
124-
125136
human_prompt = dp.fit_tokens(
126137
shrinkable=main_prompt,
127138
max_prompt_tokens=max_prompt_tokens,
@@ -168,6 +179,31 @@ def get_action(self, obs):
168179
)
169180
return ans_dict["action"], agent_info
170181

182+
def _get_queries(self):
    """Ask the chat LLM for step-level retrieval queries.

    Builds a ``StepWiseContextIdentificationPrompt`` from the current
    observation/action/thought history, sends it to ``self.chat_llm`` via the
    ``retry`` helper, and returns the parsed queries together with the model's
    chain-of-thought text.

    Returns:
        tuple[list[str], str | None]: ``(queries, think)`` where ``queries``
        contains ``self.flags.n_retrieval_queries`` strings and ``think`` is
        the model's free-text reasoning (``None`` when absent).

    Raises:
        ValueError: if, even after retries, the model returned a different
            number of queries than requested.
    """
    system_prompt = SystemMessage(dp.SystemPrompt().prompt)
    query_prompt = StepWiseContextIdentificationPrompt(
        obs_history=self.obs_history,
        actions=self.actions,
        thoughts=self.thoughts,
        obs_flags=self.flags.obs,
        n_queries=self.flags.n_retrieval_queries,  # TODO
    )

    chat_messages = Discussion([system_prompt, query_prompt.prompt])
    ans_dict = retry(
        self.chat_llm,
        chat_messages,
        n_retry=self.max_retry,
        parser=query_prompt._parse_answer,
    )

    queries = ans_dict.get("queries", [])
    # NOTE(review): an `assert` here would be stripped under `python -O`, and a
    # wrong query count is an expected LLM failure mode, not a programming
    # error — raise an explicit, descriptive exception instead.
    if len(queries) != self.flags.n_retrieval_queries:
        raise ValueError(
            f"Expected {self.flags.n_retrieval_queries} retrieval queries, "
            f"got {len(queries)}."
        )

    # TODO: we should probably propagate these chat_messages to be able to see them in xray
    return queries, ans_dict.get("think", None)
171207
def reset(self, seed=None):
172208
self.seed = seed
173209
self.plan = "No plan yet"

src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py

Lines changed: 96 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
It is based on the dynamic_prompting module from the agentlab package.
55
"""
66

7+
import json
78
import logging
89
from dataclasses import dataclass
910
from pathlib import Path
@@ -67,6 +68,8 @@ class GenericPromptFlags(dp.Flags):
6768
hint_index_path: str = None
6869
hint_retriever_path: str = None
6970
hint_num_results: int = 5
71+
n_retrieval_queries: int = 3
72+
hint_level: Literal["episode", "step"] = "episode"
7073

7174

7275
class MainPrompt(dp.Shrinkable):
@@ -81,6 +84,7 @@ def __init__(
8184
step: int,
8285
flags: GenericPromptFlags,
8386
llm: ChatModel,
87+
queries: list[str] | None = None,
8488
) -> None:
8589
super().__init__()
8690
self.flags = flags
@@ -130,6 +134,8 @@ def time_for_caution():
130134
hint_index_path=flags.hint_index_path,
131135
hint_retriever_path=flags.hint_retriever_path,
132136
hint_num_results=flags.hint_num_results,
137+
hint_level=flags.hint_level,
138+
queries=queries,
133139
)
134140
self.plan = Plan(previous_plan, step, lambda: flags.use_plan) # TODO add previous plan
135141
self.criticise = Criticise(visible=lambda: flags.use_criticise)
@@ -324,6 +330,8 @@ def __init__(
324330
hint_num_results: int = 5,
325331
skip_hints_for_current_task: bool = False,
326332
hint_retrieval_mode: Literal["direct", "llm", "emb"] = "direct",
333+
hint_level: Literal["episode", "step"] = "episode",
334+
queries: list[str] | None = None,
327335
) -> None:
328336
super().__init__(visible=use_task_hint)
329337
self.use_task_hint = use_task_hint
@@ -339,6 +347,8 @@ def __init__(
339347
self.skip_hints_for_current_task = skip_hints_for_current_task
340348
self.goal = goal
341349
self.llm = llm
350+
self.hint_level: Literal["episode", "step"] = hint_level
351+
self.queries: list[str] | None = queries
342352
self._init()
343353

344354
_prompt = "" # Task hints are added dynamically in MainPrompt
@@ -394,6 +404,7 @@ def _init(self):
394404
else:
395405
print(f"Warning: Hint database not found at {hint_db_path}")
396406
self.hint_db = pd.DataFrame(columns=["task_name", "hint"])
407+
397408
self.hints_source = HintsSource(
398409
hint_db_path=hint_db_path.as_posix(),
399410
hint_retrieval_mode=self.hint_retrieval_mode,
@@ -448,7 +459,16 @@ def get_hints_for_task(self, task_name: str) -> str:
448459
return ""
449460

450461
try:
451-
task_hints = self.hints_source.choose_hints(self.llm, task_name, self.goal)
462+
# When step-level, pass queries as goal string to fit the llm_prompt
463+
goal_or_queries = self.goal
464+
if self.hint_level == "step" and self.queries:
465+
goal_or_queries = "\n".join(self.queries)
466+
467+
task_hints = self.hints_source.choose_hints(
468+
self.llm,
469+
task_name,
470+
goal_or_queries,
471+
)
452472

453473
hints = []
454474
for hint in task_hints:
@@ -466,3 +486,78 @@ def get_hints_for_task(self, task_name: str) -> str:
466486
print(f"Warning: Error getting hints for task {task_name}: {e}")
467487

468488
return ""
489+
490+
491+
class StepWiseContextIdentificationPrompt(dp.Shrinkable):
    """Prompt asking the LLM to summarize the current step as retrieval queries.

    Assembles a human message from the goal instructions, the most recent
    observation, and the action/thought history, then instructs the model to
    answer with a ``<think>...</think>`` block followed by a
    ``<queries>...</queries>`` block containing a JSON list of exactly
    ``n_queries`` strings. The parsed queries are later used to retrieve
    step-level hints.
    """

    def __init__(
        self,
        obs_history: list[dict],
        actions: list[str],
        thoughts: list[str],
        obs_flags: dp.ObsFlags,
        n_queries: int = 1,
    ) -> None:
        """Build the shrinkable sub-prompts from the episode history.

        Args:
            obs_history: All observations so far; the last one is shown in full.
            actions: Actions taken so far (aligned with ``obs_history``).
            thoughts: Model thoughts so far (aligned with ``actions``).
            obs_flags: Flags controlling how observations are rendered.
            n_queries: Number of retrieval queries the model must produce.
        """
        super().__init__()
        self.obs_flags = obs_flags
        self.n_queries = n_queries
        self.history = dp.History(obs_history, actions, None, thoughts, obs_flags)
        self.instructions = dp.GoalInstructions(obs_history[-1]["goal_object"])
        self.obs = dp.Observation(obs_history[-1], obs_flags)

        self.think = dp.Think(visible=True)  # To replace with static text maybe

    @property
    def _prompt(self) -> HumanMessage:
        """Render the full human message (goal, obs, history, instructions)."""
        prompt = HumanMessage(self.instructions.prompt)

        prompt.add_text(
            f"""\
{self.obs.prompt}\
{self.history.prompt}\
"""
        )

        example_queries = [
            "The user has started sorting a table and needs to apply multiple column criteria simultaneously.",
            "The user is attempting to configure advanced sorting options but the interface is unclear.",
            "The user has selected the first sort column and is now looking for how to add a second sort criterion.",
            "The user is in the middle of a multi-step sorting process and needs guidance on the next action.",
        ]

        # Show exactly as many examples as queries requested, so the concrete
        # example matches the required output shape.
        example_queries_str = json.dumps(example_queries[: self.n_queries], indent=2)

        prompt.add_text(
            f"""
# Querying memory

Before choosing an action, let's search our available documentation and memory for relevant context.
Generate a brief, general summary of the current status to help identify useful hints. Return your answer as follow
<think>chain of thought</think>
<queries>json list of strings</queries> for the queries. Return exactly {self.n_queries}
queries in the list.

# Concrete Example

<think>
I have to sort by client and country. I could use the built-in sort on each column but I'm not sure if
I will be able to sort by both at the same time.
</think>

<queries>
{example_queries_str}
</queries>
"""
        )

        return self.obs.add_screenshot(prompt)

    def shrink(self):
        """Shrink the history and observation sub-prompts to fit token budgets."""
        self.history.shrink()
        self.obs.shrink()

    def _parse_answer(self, text_answer):
        """Parse the LLM answer into ``{"think": ..., "queries": [...]}``.

        Malformed JSON inside ``<queries>`` is re-raised as a ``ParseError`` so
        the surrounding ``retry(...)`` loop can ask the model again instead of
        crashing the agent with an unhandled ``JSONDecodeError``.
        """
        # Local import: ParseError lives in the same module as
        # parse_html_tags_raise (already used by this file); imported here to
        # avoid touching the file-level import block.
        from agentlab.llm.llm_utils import ParseError

        ans_dict = parse_html_tags_raise(
            text_answer, keys=["think", "queries"], merge_multiple=True
        )
        # NOTE(review): with merge_multiple=True the extracted value is assumed
        # to be a single JSON string — TODO confirm against parse_html_tags_raise.
        raw_queries = ans_dict.get("queries", "[]")
        try:
            queries = json.loads(raw_queries)
        except (json.JSONDecodeError, TypeError) as e:
            raise ParseError(
                f"Content of <queries> is not valid JSON: {e}"
            ) from e
        if not isinstance(queries, list):
            raise ParseError("<queries> must contain a JSON list of strings.")
        ans_dict["queries"] = queries
        return ans_dict

0 commit comments

Comments
 (0)