Skip to content

Commit 94fa1ab

Browse files
committed
Add StepWiseQueriesPrompt for enhanced query handling in GenericAgent
1 parent fcf42b3 commit 94fa1ab

File tree

3 files changed

+114
-3
lines changed

3 files changed

+114
-3
lines changed

src/agentlab/agents/generic_agent/generic_agent.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ def obs_preprocessor(self, obs: dict) -> dict:
9898
def get_action(self, obs):
9999

100100
self.obs_history.append(obs)
101+
101102
main_prompt = MainPrompt(
102103
action_set=self.action_set,
103104
obs_history=self.obs_history,

src/agentlab/agents/generic_agent_hinter/generic_agent.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,11 @@
2323
from agentlab.llm.llm_utils import Discussion, ParseError, SystemMessage, retry
2424
from agentlab.llm.tracking import cost_tracker_decorator
2525

26-
from .generic_agent_prompt import GenericPromptFlags, MainPrompt
26+
from .generic_agent_prompt import (
27+
GenericPromptFlags,
28+
MainPrompt,
29+
StepWiseRetrievalPrompt,
30+
)
2731

2832

2933
@dataclass
@@ -102,6 +106,16 @@ def set_task_name(self, task_name: str):
102106
def get_action(self, obs):
103107

104108
self.obs_history.append(obs)
109+
110+
system_prompt = SystemMessage(dp.SystemPrompt().prompt)
111+
112+
queries, think_queries = self._get_queries()
113+
114+
# TODO
115+
# use those queries to retrieve from the database. e.g.:
116+
# hints = self.hint_db.get_hints(queries)
117+
# then add those hints to the main prompt
118+
105119
main_prompt = MainPrompt(
106120
action_set=self.action_set,
107121
obs_history=self.obs_history,
@@ -120,8 +134,6 @@ def get_action(self, obs):
120134

121135
max_prompt_tokens, max_trunc_itr = self._get_maxes()
122136

123-
system_prompt = SystemMessage(dp.SystemPrompt().prompt)
124-
125137
human_prompt = dp.fit_tokens(
126138
shrinkable=main_prompt,
127139
max_prompt_tokens=max_prompt_tokens,
@@ -168,6 +180,31 @@ def get_action(self, obs):
168180
)
169181
return ans_dict["action"], agent_info
170182

183+
def _get_queries(self):
    """Ask the chat LLM for retrieval queries to use for hinting.

    Builds a ``StepWiseRetrievalPrompt`` from the agent's observation
    history, past actions and past thoughts, sends it (preceded by the
    system prompt) through ``retry`` and returns the parsed result.

    Returns:
        A ``(queries, think)`` tuple: ``queries`` is the value parsed under
        the ``"queries"`` key and ``think`` is the value under the
        ``"think"`` key, or ``None`` when that key is absent.

    Raises:
        ValueError: if the number of returned queries differs from
            ``self.flags.n_retrieval_queries``.
    """
    expected_count = self.flags.n_retrieval_queries

    system_prompt = SystemMessage(dp.SystemPrompt().prompt)
    query_prompt = StepWiseRetrievalPrompt(
        obs_history=self.obs_history,
        actions=self.actions,
        thoughts=self.thoughts,
        obs_flags=self.flags.obs,
        n_queries=expected_count,  # TODO
    )

    chat_messages = Discussion([system_prompt, query_prompt.prompt])
    ans_dict = retry(
        self.chat_llm,
        chat_messages,
        n_retry=self.max_retry,
        parser=query_prompt._parse_answer,
    )

    queries = ans_dict.get("queries", [])
    # Explicit check rather than `assert`: assertions are removed when Python
    # runs with -O, which would let a wrong number of queries through silently.
    if len(queries) != expected_count:
        raise ValueError(
            f"Expected exactly {expected_count} retrieval queries, "
            f"got {len(queries)}: {queries!r}"
        )

    # TODO: we should probably propagate these chat_messages to be able to see them in xray
    return queries, ans_dict.get("think")
207+
171208
def reset(self, seed=None):
172209
self.seed = seed
173210
self.plan = "No plan yet"

src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
It is based on the dynamic_prompting module from the agentlab package.
55
"""
66

7+
import json
78
import logging
89
from dataclasses import dataclass
910
from pathlib import Path
@@ -60,6 +61,7 @@ class GenericPromptFlags(dp.Flags):
6061
add_missparsed_messages: bool = True
6162
max_trunc_itr: int = 20
6263
flag_group: str = None
64+
n_retrieval_queries: int = 3
6365

6466

6567
class MainPrompt(dp.Shrinkable):
@@ -396,3 +398,74 @@ def get_hints_for_task(self, task_name: str) -> str:
396398
print(f"Warning: Error getting hints for task {task_name}: {e}")
397399

398400
return ""
401+
402+
403+
class StepWiseRetrievalPrompt(dp.Shrinkable):
    """Prompt asking the LLM for search queries to run before acting.

    The rendered prompt contains the goal instructions, the latest
    observation, the interaction history, and a section telling the model to
    answer with a ``<think>`` block followed by a ``<queries>`` block holding
    a JSON list of exactly ``n_queries`` strings.
    """

    def __init__(
        self,
        obs_history: list[dict],
        actions: list[str],
        thoughts: list[str],
        obs_flags: dp.ObsFlags,
        n_queries: int = 3,
    ) -> None:
        """Build the prompt components from the agent's current state.

        Args:
            obs_history: Observations so far; the last one is the current
                observation and must contain a ``"goal_object"`` entry.
            actions: Past actions, forwarded to ``dp.History``.
            thoughts: Past thoughts, forwarded to ``dp.History``.
            obs_flags: Observation flags, forwarded to ``dp.History`` and
                ``dp.Observation``.
            n_queries: Number of queries the model is asked to return.
        """
        super().__init__()
        self.obs_flags = obs_flags
        self.n_queries = n_queries

        current_obs = obs_history[-1]
        # Third positional argument of dp.History is passed as None here
        # (presumably the memories slot -- TODO confirm against dp.History).
        self.history = dp.History(obs_history, actions, None, thoughts, obs_flags)
        self.instructions = dp.GoalInstructions(current_obs["goal_object"])
        self.obs = dp.Observation(current_obs, obs_flags)

        # NOTE: not referenced by `_prompt` below; the answer format is
        # spelled out as static text there instead.
        self.think = dp.Think(visible=True)  # To replace with static text maybe

    @property
    def _prompt(self) -> HumanMessage:
        """Assemble the message: goal, observation, history, query instructions."""
        prompt = HumanMessage(self.instructions.prompt)

        prompt.add_text(
            f"""\
{self.obs.prompt}\
{self.history.prompt}\
"""
        )

        example_queries = [
            "How to sort with multiple columns on the ServiceNow platform?",
            "What are the potential challenges of sorting by multiple columns?",
            "How to handle sorting by multiple columns in a table?",
            "Can I use the filter tool to sort by multiple columns?",
        ]

        # Show at most `n_queries` examples so the example mirrors the request
        # (only four examples are available).
        example_queries_str = json.dumps(example_queries[: self.n_queries], indent=2)

        prompt.add_text(
            f"""
# Querying memory

Before choosing an action, let's search our available documentation and memory on how to approach this step.
This could provide valuable hints on how to properly solve this task. Return your answer as follow
<think>chain of thought</think>
<queries>json list of strings</queries> for the queries. Return exactly {self.n_queries}
queries in the list.

# Concrete Example

<think>
I have to sort by client and country. I could use the built-in sort on each column but I'm not sure if
I will be able to sort by both at the same time.
</think>

<queries>
{example_queries_str}
</queries>
"""
        )

        return self.obs.add_screenshot(prompt)

    def _parse_answer(self, text_answer):
        """Extract the ``think`` and ``queries`` fields from the LLM answer.

        Args:
            text_answer: Raw text returned by the model.

        Returns:
            The dict produced by ``parse_html_tags_raise`` with ``"queries"``
            replaced by the decoded JSON list of query strings.

        Raises:
            ParseError: if the ``<queries>`` content is not valid JSON, is not
                a list of strings, or does not contain exactly
                ``self.n_queries`` items. Raising ParseError (rather than
                letting ``json.JSONDecodeError`` escape) presumably lets the
                caller's retry loop re-prompt the model -- TODO confirm
                against ``retry``.
        """
        # Local import: only ParseError is new to this block, and the visible
        # part of this module does not show it being imported.
        from agentlab.llm.llm_utils import ParseError

        ans_dict = parse_html_tags_raise(
            text_answer, keys=["think", "queries"], merge_multiple=True
        )

        raw_queries = ans_dict.get("queries", "[]")
        try:
            queries = json.loads(raw_queries)
        except json.JSONDecodeError as err:
            raise ParseError(
                f"The content of <queries> is not valid JSON ({err}). "
                "Return a JSON list of strings."
            ) from err

        if not isinstance(queries, list) or not all(isinstance(q, str) for q in queries):
            raise ParseError("The content of <queries> must be a JSON list of strings.")

        if len(queries) != self.n_queries:
            raise ParseError(
                f"Expected exactly {self.n_queries} queries in <queries>, "
                f"got {len(queries)}."
            )

        ans_dict["queries"] = queries
        return ans_dict

0 commit comments

Comments
 (0)