Skip to content

Commit fcffc2e

Browse files
enable reprompt tool use agent from controller
1 parent af6d888 commit fcffc2e

File tree

4 files changed

+106
-49
lines changed

4 files changed

+106
-49
lines changed

src/agentlab/agents/tool_use_agent/hint_db.csv

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,26 @@ June 11,miniwob.drag-items,30,claude-3-7-sonnet-20250219,MultiToolUse-claude-3-7
1616
June 18,miniwob.count-shape,23,claude-3-7-sonnet-20250219,MultiToolUse-claude-3-7-sonnet-20250219,miniwob,miniwob,allac,Shape and letters size comparison in miniwob,"Shapes or items have different colors and different size. Size is relative to the other objects in the white area and is either ""large"" or ""small"". Shapes that are larger than the average shape or letter are considered ""large"". Others are ""small""."
1717
June 18,miniwob.count-shape,23,claude-3-7-sonnet-20250219,MultiToolUse-claude-3-7-sonnet-20250219,miniwob,miniwob,allac,communicate answer in miniwob,Answer by clicking one of the buttons describing multiple choices.
1818
June 18,miniwob.count-shape,23,claude-3-7-sonnet-20250219,MultiToolUse-claude-3-7-sonnet-20250219,miniwob,miniwob,allac,Simbols of colors in miniwob,"Colors are distinct in this task, e.g., cyan is not a type of blue. "
19-
June 18,miniwob.form-sequence-2,23,claude-3-7-sonnet-20250219,MultiToolUse-claude-3-7-sonnet-20250219,miniwob,miniwob,allac,Reporting results in miniwob,Make sure to click submit to finish the task.
19+
June 18,miniwob.form-sequence-2,23,claude-3-7-sonnet-20250219,MultiToolUse-claude-3-7-sonnet-20250219,miniwob,miniwob,allac,Reporting results in miniwob,Make sure to click submit to finish the task.
20+
July 16,workarena.servicenow.sort-asset-list,406,gpt-4-1,ToolUseAgent-gpt-4-1,workarena,workarena,patricebechard,Sorting lists in ServiceNow,"1. **Navigate to Your Table/List**
21+
22+
* For example, go to **Incident > All** or any other table you want to view.
23+
24+
2. **Sort by One or Multiple Columns**
25+
26+
* `click` on the ""show / hide filter"" button (funnel icon) at the top left of the page to open the filter row.
27+
* Repeat the following steps for each column you want to sort by:
28+
* `click` on the ""Add Sort"" button to add a new sort filter. This will create a new ordering filter row with two comboboxes under the heading ""Order results by the following fields"".
29+
* `fill` the first combobox with the appropriate field name you want to sort by. MAKE SURE to use the exact field name provided.
30+
* `press` Enter after typing the field name. It is VERY IMPORTANT that you do this before doing anything else. DO NOT click on the run filter button before having confirmed your choice by explicitly pressing ENTER.
31+
* `select_option` for the appropriate ordering between ascending (a to z) or descending (z to a) in the second combobox.
32+
* Once all sort filters have been added, `click` the ""Run filter"" button to apply the sort.
33+
34+
Notes:
35+
* NEVER directly sort the columns using the table header.
36+
* NEVER add columns via the Personalize List menu.
37+
38+
3. **Resetting or Clearing Sorting**
39+
40+
* To reset sorting, click another column, or click again to toggle.
41+
* In the filter bar, you may see a ""Sorted by..."" indicator—clear or change it as needed."

src/agentlab/agents/tool_use_agent/tool_use_agent.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,22 +8,14 @@
88

99
import bgym
1010
import pandas as pd
11-
from bgym import Benchmark as BgymBenchmark
12-
from browsergym.core.observation import extract_screenshot
13-
from browsergym.utils.obs import (
14-
flatten_axtree_to_str,
15-
flatten_dom_to_str,
16-
overlay_som,
17-
prune_html,
18-
)
19-
2011
from agentlab.agents.agent_args import AgentArgs
2112
from agentlab.benchmarks.abstract_env import AbstractBenchmark as AgentLabBenchmark
2213
from agentlab.benchmarks.osworld import OSWorldActionSet
2314
from agentlab.llm.base_api import BaseModelArgs
2415
from agentlab.llm.llm_utils import image_to_png_base64_url
2516
from agentlab.llm.response_api import (
2617
APIPayload,
18+
AzureOpenAIResponseModelArgs,
2719
ClaudeResponseModelArgs,
2820
LLMOutput,
2921
MessageBuilder,
@@ -33,6 +25,14 @@
3325
ToolCalls,
3426
)
3527
from agentlab.llm.tracking import cost_tracker_decorator
28+
from bgym import Benchmark as BgymBenchmark
29+
from browsergym.core.observation import extract_screenshot
30+
from browsergym.utils.obs import (
31+
flatten_axtree_to_str,
32+
flatten_dom_to_str,
33+
overlay_som,
34+
prune_html,
35+
)
3636

3737

3838
@dataclass
@@ -43,8 +43,8 @@ def _init(self):
4343

4444
def make(self) -> "Block":
4545
"""Returns a copy so the init can start adding some stuff to `self` without changing the
46-
original datatclass that should only contain a config.
47-
The aim is avoid having 2 calss definition for each block, e.g. Block and BlockArgs.
46+
original dataclass that should only contain a config.
47+
The aim is to avoid having 2 class definitions for each block, e.g. Block and BlockArgs.
4848
4949
Returns:
5050
Block: A copy of the current block instance with initialization applied.
@@ -387,7 +387,6 @@ def __init__(
387387
self.config.action_subsets, multiaction=self.config.multiaction # type: ignore
388388
)
389389
self.tools = self.action_set.to_tool_description(api=model_args.api)
390-
391390
self.call_ids = []
392391

393392
self.llm = model_args.make_model()
@@ -595,8 +594,8 @@ def get_action(self, obs: Any) -> float:
595594
task_hint=TaskHint(use_task_hint=True),
596595
keep_last_n_obs=None,
597596
multiaction=True, # whether to use multi-action or not
598-
# action_subsets=("bid",),
599-
action_subsets=("coord"),
597+
action_subsets=("bid",),
598+
# action_subsets=("coord"),
600599
# action_subsets=("coord", "bid"),
601600
)
602601

src/agentlab/analyze/agent_controller.py

Lines changed: 54 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
)
2626
from agentlab.experiments.exp_utils import RESULTS_DIR
2727
from agentlab.experiments.loop import ExpArgs, StepInfo, save_package_versions
28-
from agentlab.llm.llm_utils import Discussion
28+
from agentlab.llm.response_api import LLMOutput
2929
from bgym import DEFAULT_BENCHMARKS
3030
from dotenv import load_dotenv
3131
from transformers import AutoTokenizer
@@ -188,7 +188,12 @@ def step_agent_history(action, action_info):
188188
st.session_state.action_history.append(action)
189189
st.session_state.action_info_history.append(action_info)
190190
st.session_state.thought_history.append(action_info.think)
191-
st.session_state.prompt_history.append(get_prompt(action_info))
191+
if isinstance(st.session_state.agent, GenericAgent):
192+
st.session_state.prompt_history.append(get_prompt(action_info))
193+
elif isinstance(st.session_state.agent, ToolUseAgent):
194+
st.session_state.prompt_history.append(
195+
"\n".join([elem.to_markdown() for elem in st.session_state.agent.discussion.flatten()])
196+
)
192197

193198
# HACK: memory history can only be obtained via the agent
194199
if isinstance(st.session_state.agent, GenericAgent):
@@ -229,10 +234,31 @@ def revert_agent_history():
229234

230235
def revert_agent_state():
231236
logger.info("Reverting agent state")
232-
st.session_state.agent.obs_history.pop()
233-
st.session_state.agent.actions.pop()
234-
st.session_state.agent.thoughts.pop()
235-
st.session_state.agent.memories.pop()
237+
if isinstance(st.session_state.agent, GenericAgent):
238+
st.session_state.agent.obs_history.pop()
239+
st.session_state.agent.actions.pop()
240+
st.session_state.agent.thoughts.pop()
241+
st.session_state.agent.memories.pop()
242+
elif isinstance(st.session_state.agent, ToolUseAgent):
243+
num_groups = len(st.session_state.agent.discussion.groups)
244+
if num_groups == 3:
245+
# start from blank state
246+
st.session_state.agent.discussion.groups = []
247+
st.session_state.agent.last_response = LLMOutput()
248+
st.session_state.agent._responses = []
249+
elif num_groups > 3:
250+
# get rid of the last group (last action), and remove everything from the other previous group except for the action
251+
st.session_state.agent.discussion.groups.pop()
252+
last_group = copy.deepcopy(st.session_state.agent.discussion.groups[-1])
253+
last_group.summary = None
254+
last_group.messages = last_group.messages[:0] # remove everything from last group
255+
st.session_state.agent.discussion.groups[-1] = last_group
256+
st.session_state.agent._responses.pop()
257+
st.session_state.agent.last_response = copy.deepcopy(
258+
st.session_state.agent._responses[-1]
259+
)
260+
else:
261+
raise Exception("Invalid number of groups")
236262

237263

238264
def restore_env_history(step: int):
@@ -534,9 +560,17 @@ def load_session(exp_files):
534560
st.session_state.action_history.append(step_info.action)
535561
st.session_state.action_info_history.append(step_info.agent_info)
536562
st.session_state.thought_history.append(step_info.agent_info.get("think", None))
537-
st.session_state.prompt_history.append(get_prompt(step_info.agent_info))
538563
if isinstance(st.session_state.agent, GenericAgent):
539564
st.session_state.memory_history.append(step_info.agent_info.get("memory", None))
565+
st.session_state.prompt_history.append(get_prompt(step_info.agent_info))
566+
elif isinstance(st.session_state.agent, ToolUseAgent):
567+
st.session_state.prompt_history.append(
568+
"\n".join(
569+
[elem.to_markdown() for elem in st.session_state.agent.discussion.flatten()]
570+
)
571+
)
572+
else:
573+
raise ValueError(f"Unknown agent type: {type(st.session_state.agent)}")
540574
st.session_state.obs_history.append(step_info.obs)
541575
st.session_state.reward_history.append(step_info.reward)
542576
st.session_state.terminated_history.append(step_info.terminated)
@@ -573,7 +607,8 @@ def clean_session():
573607
def prepare_agent():
574608
st.session_state.agent_args.prepare()
575609
st.session_state.agent = st.session_state.agent_args.make_agent()
576-
st.session_state.agent.set_task_name(st.session_state.task)
610+
if isinstance(st.session_state.agent, ToolUseAgent):
611+
st.session_state.agent.set_task_name(st.session_state.task)
577612

578613

579614
def set_environment_info():
@@ -863,9 +898,9 @@ def set_prompt_modifier():
863898
st.session_state.agent.config.obs.use_tabs = st.checkbox(
864899
"Use tabs", value=st.session_state.agent.config.obs.use_tabs
865900
)
866-
st.session_state.agent.config.obs.add_mouse_pointer = st.checkbox(
867-
"Add mouse pointer", value=st.session_state.agent.config.obs.add_mouse_pointer
868-
)
901+
# st.session_state.agent.config.obs.add_mouse_pointer = st.checkbox(
902+
# "Add mouse pointer", value=st.session_state.agent.config.obs.add_mouse_pointer
903+
# )
869904
st.session_state.agent.config.obs.use_zoomed_webpage = st.checkbox(
870905
"Use zoomed webpage", value=st.session_state.agent.config.obs.use_zoomed_webpage
871906
)
@@ -1107,7 +1142,14 @@ def set_axtree_tab():
11071142

11081143

11091144
def set_prompt_tab():
1110-
st.code(st.session_state.prompt_history[-1], language=None, wrap_lines=True)
1145+
if isinstance(st.session_state.agent, GenericAgent):
1146+
st.code(st.session_state.prompt_history[-1], language=None, wrap_lines=True)
1147+
elif isinstance(st.session_state.agent, ToolUseAgent):
1148+
st.markdown(st.session_state.prompt_history[-1])
1149+
1150+
st.markdown(f"## Last summary:\n{st.session_state.agent.discussion.get_last_summary()}")
1151+
else:
1152+
raise ValueError(f"Unknown agent type: {type(st.session_state.agent)}")
11111153

11121154

11131155
def set_previous_steps_tab():

src/agentlab/llm/response_api.py

Lines changed: 16 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -589,29 +589,26 @@ class AzureOpenAIResponseModel(OpenAIResponseModel):
589589
def __init__(
590590
self,
591591
model_name: str,
592+
base_url: Optional[str] = None,
592593
api_key: Optional[str] = None,
593-
temperature: float = 0.5,
594-
max_tokens: int = 100,
595-
extra_kwargs: Optional[Dict[str, Any]] = None,
596-
**kwargs,
594+
temperature: float | None = None,
595+
max_tokens: int | None = 100,
597596
):
598597
api_key = os.getenv("AZURE_OPENAI_API_KEY")
599-
self.tools = kwargs.pop("tools", None)
600-
logging.info(f"Tools: {self.tools}")
601-
super().__init__(
602-
model_name=model_name,
603-
api_key=api_key,
604-
temperature=temperature,
605-
max_tokens=max_tokens,
606-
extra_kwargs=extra_kwargs,
607-
**kwargs,
608-
)
609-
# azure client takes extra kwargs
610-
self.client = OpenAI(
611-
api_key=api_key,
612-
base_url=urljoin(os.getenv("AZURE_OPENAI_ENDPOINT"), "openai/v1"),
613-
default_query={"api-version": "preview"},
598+
base_url = urljoin(os.getenv("AZURE_OPENAI_ENDPOINT"), "openai/v1")
599+
self.action_space_as_tools = True # this should be a config
600+
super().__init__( # This is passed to BaseModel
601+
model_name=model_name, api_key=api_key, temperature=temperature, max_tokens=max_tokens
614602
)
603+
client_args = {}
604+
if base_url is not None:
605+
client_args["base_url"] = base_url
606+
if api_key is not None:
607+
client_args["api_key"] = api_key
608+
client_args["default_query"] = {"api-version": "preview"}
609+
self.client = OpenAI(**client_args)
610+
# Init pricing tracker after super() so that all attributes have been set.
611+
self.init_pricing_tracker(pricing_api="openai") # Use the PricingMixin
615612

616613

617614
class OpenAIChatCompletionModel(BaseModelWithPricing):
@@ -958,9 +955,6 @@ def make_model(self, extra_kwargs=None, **kwargs):
958955
model_name=self.model_name,
959956
temperature=self.temperature,
960957
max_tokens=self.max_new_tokens,
961-
extra_kwargs=extra_kwargs,
962-
pricing_api="openai",
963-
**kwargs,
964958
)
965959

966960

0 commit comments

Comments
 (0)