Skip to content

Commit 87e2510

Browse files
Improve generic agent hinter (#309)
* Make LLM retrieval topic index selection more robust
1 parent 02dd9b8 commit 87e2510

File tree

4 files changed

+26
-6
lines changed

4 files changed

+26
-6
lines changed

src/agentlab/agents/generic_agent_hinter/generic_agent.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,9 @@ def __init__(
9393

9494
self.flags = flags
9595
if self.flags.hint_db_path is not None:
96-
assert os.path.exists(self.flags.hint_db_path), f"Hint database path {self.flags.hint_db_path} does not exist."
96+
assert os.path.exists(
97+
self.flags.hint_db_path
98+
), f"Hint database path {self.flags.hint_db_path} does not exist."
9799
self.action_set = self.flags.action.action_set.make_action_set()
98100
self._obs_preprocessor = dp.make_obs_preprocessor(flags.obs)
99101

@@ -118,7 +120,9 @@ def get_action(self, obs):
118120

119121
# use those queries to retrieve from the database and pass to prompt if step-level
120122
self.queries = (
121-
self._get_queries()[0] if getattr(self.flags, "hint_level", "episode") == "step" else None
123+
self._get_queries()[0]
124+
if getattr(self.flags, "hint_level", "episode") == "step"
125+
else None
122126
)
123127

124128
# get hints
@@ -204,6 +208,7 @@ def _get_queries(self):
204208
)
205209

206210
chat_messages = Discussion([system_prompt, query_prompt.prompt])
211+
# BUG: Parsing fails multiple times.
207212
ans_dict = retry(
208213
self.chat_llm,
209214
chat_messages,

src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
logger = logging.getLogger(__name__)
2121

22+
2223
@dataclass
2324
class GenericPromptFlags(dp.Flags):
2425
"""
@@ -403,6 +404,8 @@ def _parse_answer(self, text_answer):
403404
ans_dict["queries"] = json.loads(raw_queries)
404405
except Exception as e:
405406
t = text_answer.replace("\n", "\\n")
406-
logger.warning(f"Failed to parse queries: {e}. Queries block content: '{ans_dict['queries']}'. RAW llm answer: '{t}'. Will retry")
407+
logger.warning(
408+
f"Failed to parse queries: {e}. Queries block content: '{ans_dict['queries']}'. RAW llm answer: '{t}'. Will retry"
409+
)
407410
raise e
408411
return ans_dict

src/agentlab/agents/generic_agent_hinter/tmlr_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
use_think_history=True, # gpt-4o config except for this line
2424
use_diff=False,
2525
html_type="pruned_html",
26-
use_screenshot=True,
26+
use_screenshot=False,
2727
use_som=False,
2828
extract_visible_tag=True,
2929
extract_clickable_tag=True,

src/agentlab/utils/hinting.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,14 @@
1212
import pandas as pd
1313
import requests
1414
from agentlab.llm.chat_api import ChatModel
15+
import re
16+
from agentlab.llm.response_api import APIPayload
1517

1618
logger = logging.getLogger(__name__)
1719

1820

1921
class HintsSource:
22+
2023
def __init__(
2124
self,
2225
hint_db_path: str,
@@ -27,7 +30,8 @@ def __init__(
2730
embedder_server: str = "http://localhost:5000",
2831
llm_prompt: str = """We're choosing hints to help solve the following task:\n{goal}.\n
2932
You need to choose the most relevant hints topic from the following list:\n\nHint topics:\n{topics}\n
30-
Choose hint topic for the task and return only its number, e.g. 1. If you don't know the answer, return -1.""",
33+
Choose hint topic for the task and return only its number. Use the following output format:
34+
<choice>index</choice> for e.g. <choice>0</choice> for the topic with index 0. If you don't know the answer, return <choice>-1</choice>""",
3135
) -> None:
3236
self.hint_db_path = hint_db_path
3337
self.hint_retrieval_mode = hint_retrieval_mode
@@ -96,7 +100,15 @@ def choose_hints_llm(self, llm, goal: str, task_name: str) -> list[str]:
96100
else:
97101
response: str = llm(APIPayload(messages=[llm.msg.user().add_text(prompt)])).think
98102
try:
99-
topic_number = json.loads(response)
103+
matches = re.findall(r"<choice>(-?\d+)</choice>", response)
104+
if not matches:
105+
logger.error(f"No choice tags found in LLM response: {response}")
106+
return []
107+
if len(matches) > 1:
108+
logger.warning(
109+
f"LLM selected multiple topics for retrieval using only the first one."
110+
)
111+
topic_number = int(matches[0])
100112
if topic_number < 0 or topic_number >= len(hint_topics):
101113
logger.error(f"Wrong LLM hint id response: {response}, no hints")
102114
return []

0 commit comments

Comments
 (0)