Skip to content

Commit c674094

Browse files
committed
Enhance ToolUseAgent with token counting and improved message handling for observations
1 parent 16cc3cd commit c674094

File tree

1 file changed

+28
-17
lines changed
  • src/agentlab/agents/tool_use_agent

1 file changed

+28
-17
lines changed

src/agentlab/agents/tool_use_agent/agent.py

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import json
22
import logging
3-
from copy import deepcopy as copy
4-
from dataclasses import asdict, dataclass
3+
from dataclasses import dataclass
54
from typing import TYPE_CHECKING, Any
65

76
import bgym
@@ -103,6 +102,16 @@ def __init__(
103102

104103
self.tools = self.action_set.to_tool_description(api=model_args.api)
105104

105+
# count tools tokens
106+
from agentlab.llm.llm_utils import count_tokens
107+
108+
tool_str = json.dumps(self.tools, indent=2)
109+
print(f"Tool description: {tool_str}")
110+
tool_tokens = count_tokens(tool_str, model_args.model_name)
111+
print(f"Tool tokens: {tool_tokens}")
112+
113+
self.call_ids = []
114+
106115
# self.tools.append(
107116
# {
108117
# "type": "function",
@@ -154,26 +163,28 @@ def get_action(self, obs: Any) -> float:
154163
goal_message.add_image(content["image_url"])
155164
self.messages.append(goal_message)
156165

166+
extra_info = []
167+
168+
extra_info.append(
169+
"Use ControlOrMeta instead of Control and Meta for keyboard shortcuts, to be cross-platform compatible. E.g. use ControlOrMeta for mutliple selection in lists"
170+
)
171+
172+
self.messages.append(MessageBuilder.user().add_text("\n".join(extra_info)))
173+
157174
if self.use_first_obs:
175+
msg = "Here is the first observation."
176+
screenshot_key = "screenshot_tag" if self.tag_screenshot else "screenshot"
158177
if self.tag_screenshot:
159-
message = MessageBuilder.user().add_text(
160-
"Here is the first observation. A red dot on screenshots indicate the previous click action:"
161-
)
162-
message.add_image(image_to_png_base64_url(obs["screenshot_tag"]))
163-
else:
164-
message = MessageBuilder.user().add_text("Here is the first observation:")
165-
message.add_image(image_to_png_base64_url(obs["screenshot"]))
178+
msg += " A red dot on screenshots indicate the previous click action."
179+
message = MessageBuilder.user().add_text(msg)
180+
message.add_image(image_to_png_base64_url(obs[screenshot_key]))
166181
self.messages.append(message)
167182
else:
168183
if obs["last_action_error"] == "":
169-
if self.tag_screenshot:
170-
tool_message = MessageBuilder.tool().add_image(
171-
image_to_png_base64_url(obs["screenshot_tag"])
172-
)
173-
else:
174-
tool_message = MessageBuilder.tool().add_image(
175-
image_to_png_base64_url(obs["screenshot"])
176-
)
184+
screenshot_key = "screenshot_tag" if self.tag_screenshot else "screenshot"
185+
tool_message = MessageBuilder.tool().add_image(
186+
image_to_png_base64_url(obs[screenshot_key])
187+
)
177188
tool_message.add_tool_id(self.previous_call_id)
178189
self.messages.append(tool_message)
179190
else:

0 commit comments

Comments
 (0)