Skip to content

Commit 2e275cd

Browse files
exact goal loading in the tool-use-agent
1 parent e91ab05 commit 2e275cd

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

src/agentlab/agents/tool_use_agent/tool_use_agent.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -325,12 +325,17 @@ def _init(self):
325325
embedder_server=self.embedder_server,
326326
)
327327

328-
def apply(self, llm, discussion: StructuredDiscussion, task_name: str) -> dict:
328+
def apply(self, llm, discussion: StructuredDiscussion, obs: dict, task_name: str) -> dict:
329329
if not self.use_task_hint:
330330
return {}
331331

332-
goal = "\n".join([c.get("text", "") for c in discussion.groups[0].messages[1].content])
333-
task_hints = self.hints_source.choose_hints(llm, task_name, goal)
332+
# goal = "\n".join([c.get("text", "") for c in discussion.groups[0].messages[1].content])
333+
try:
334+
goal_text = obs["goal_object"][0]["text"]
335+
except (KeyError, IndexError):
336+
Warning("Goal text not found in observation")
337+
goal_text = ""
338+
task_hints = self.hints_source.choose_hints(llm, task_name, goal_text)
334339

335340
hints = []
336341
for hint in task_hints:
@@ -472,7 +477,7 @@ def get_action(self, obs: Any) -> float:
472477

473478
self.config.summarizer.apply_init(self.llm, self.discussion)
474479
self.config.general_hints.apply(self.llm, self.discussion)
475-
self.task_hint.apply(self.llm, self.discussion, self.task_name)
480+
self.task_hint.apply(self.llm, self.discussion, obs=obs, task_name=self.task_name)
476481

477482
self.discussion.new_group()
478483

src/agentlab/utils/hinting.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@ def __init__(
5252
self.hint_db_path,
5353
header=0,
5454
index_col=None,
55-
dtype=str,
5655
converters={
5756
"trace_paths_json": lambda x: json.loads(x) if pd.notna(x) else [],
5857
"source_trace_goals": lambda x: json.loads(x) if pd.notna(x) else [],

0 commit comments

Comments
 (0)