|
1 | 1 | import json |
2 | 2 | import logging |
3 | | -from copy import deepcopy as copy |
4 | | -from dataclasses import asdict, dataclass |
| 3 | +from dataclasses import dataclass |
5 | 4 | from typing import TYPE_CHECKING, Any |
6 | 5 |
|
7 | 6 | import bgym |
@@ -103,6 +102,16 @@ def __init__( |
103 | 102 |
|
104 | 103 | self.tools = self.action_set.to_tool_description(api=model_args.api) |
105 | 104 |
|
| 105 | + # Count the tokens in the JSON-serialized tool descriptions (debug output). |
| 106 | + from agentlab.llm.llm_utils import count_tokens |
| 107 | + |
| 108 | + tool_str = json.dumps(self.tools, indent=2) |
| 109 | + print(f"Tool description: {tool_str}") |
| 110 | + tool_tokens = count_tokens(tool_str, model_args.model_name) |
| 111 | + print(f"Tool tokens: {tool_tokens}") |
| 112 | + |
| 113 | + self.call_ids = [] |
| 114 | + |
106 | 115 | # self.tools.append( |
107 | 116 | # { |
108 | 117 | # "type": "function", |
@@ -154,26 +163,28 @@ def get_action(self, obs: Any) -> float: |
154 | 163 | goal_message.add_image(content["image_url"]) |
155 | 164 | self.messages.append(goal_message) |
156 | 165 |
|
| 166 | + extra_info = [] |
| 167 | + |
| 168 | + extra_info.append( |
| 169 | + "Use ControlOrMeta instead of Control and Meta for keyboard shortcuts, to be cross-platform compatible. E.g. use ControlOrMeta for mutliple selection in lists" |
| 170 | + ) |
| 171 | + |
| 172 | + self.messages.append(MessageBuilder.user().add_text("\n".join(extra_info))) |
| 173 | + |
157 | 174 | if self.use_first_obs: |
| 175 | + msg = "Here is the first observation." |
| 176 | + screenshot_key = "screenshot_tag" if self.tag_screenshot else "screenshot" |
158 | 177 | if self.tag_screenshot: |
159 | | - message = MessageBuilder.user().add_text( |
160 | | - "Here is the first observation. A red dot on screenshots indicate the previous click action:" |
161 | | - ) |
162 | | - message.add_image(image_to_png_base64_url(obs["screenshot_tag"])) |
163 | | - else: |
164 | | - message = MessageBuilder.user().add_text("Here is the first observation:") |
165 | | - message.add_image(image_to_png_base64_url(obs["screenshot"])) |
| 178 | + msg += " A red dot on screenshots indicate the previous click action." |
| 179 | + message = MessageBuilder.user().add_text(msg) |
| 180 | + message.add_image(image_to_png_base64_url(obs[screenshot_key])) |
166 | 181 | self.messages.append(message) |
167 | 182 | else: |
168 | 183 | if obs["last_action_error"] == "": |
169 | | - if self.tag_screenshot: |
170 | | - tool_message = MessageBuilder.tool().add_image( |
171 | | - image_to_png_base64_url(obs["screenshot_tag"]) |
172 | | - ) |
173 | | - else: |
174 | | - tool_message = MessageBuilder.tool().add_image( |
175 | | - image_to_png_base64_url(obs["screenshot"]) |
176 | | - ) |
| 184 | + screenshot_key = "screenshot_tag" if self.tag_screenshot else "screenshot" |
| 185 | + tool_message = MessageBuilder.tool().add_image( |
| 186 | + image_to_png_base64_url(obs[screenshot_key]) |
| 187 | + ) |
177 | 188 | tool_message.add_tool_id(self.previous_call_id) |
178 | 189 | self.messages.append(tool_message) |
179 | 190 | else: |
|
0 commit comments