From 6ae2c993af9e2cd03237cdc4cc8c3e29e77933e6 Mon Sep 17 00:00:00 2001 From: Patrice Bechard Date: Tue, 10 Jun 2025 13:35:43 -0400 Subject: [PATCH 01/24] add azure agents --- pyproject.toml | 1 + src/agentlab/agents/generic_agent/__init__.py | 16 ++++++++ .../agents/generic_agent/agent_configs.py | 38 +++++++++++++++++++ src/agentlab/llm/llm_configs.py | 33 ++++++++++++++++ 4 files changed, 88 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 1292836a..ae49f143 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,3 +58,4 @@ exclude = ''' [project.scripts] agentlab-assistant = "agentlab.ui_assistant:main" agentlab-xray = "agentlab.analyze.agent_xray:main" +agentlab-controller = "agentlab.analyze.run_agentlab_controller:main" diff --git a/src/agentlab/agents/generic_agent/__init__.py b/src/agentlab/agents/generic_agent/__init__.py index cb5bbb7f..9aecbb5f 100644 --- a/src/agentlab/agents/generic_agent/__init__.py +++ b/src/agentlab/agents/generic_agent/__init__.py @@ -26,6 +26,14 @@ AGENT_o3_MINI, FLAGS_GPT_4o, GenericAgentArgs, + AGENT_AZURE_4o_MINI, + AGENT_AZURE_4o, + AGENT_AZURE_4o_VISION, + AGENT_AZURE_4o_MINI_VISION, + AGENT_AZURE_41, + AGENT_AZURE_41_MINI, + AGENT_AZURE_41_VISION, + AGENT_AZURE_41_MINI_VISION, ) __all__ = [ @@ -46,4 +54,12 @@ "AGENT_4o_VISION", "AGENT_4o_MINI_VISION", "AGENT_CLAUDE_SONNET_35_VISION", + "AGENT_AZURE_4o_MINI", + "AGENT_AZURE_4o", + "AGENT_AZURE_4o_VISION", + "AGENT_AZURE_4o_MINI_VISION", + "AGENT_AZURE_41", + "AGENT_AZURE_41_MINI", + "AGENT_AZURE_41_VISION", + "AGENT_AZURE_41_MINI_VISION", ] diff --git a/src/agentlab/agents/generic_agent/agent_configs.py b/src/agentlab/agents/generic_agent/agent_configs.py index f50367d8..17728a82 100644 --- a/src/agentlab/agents/generic_agent/agent_configs.py +++ b/src/agentlab/agents/generic_agent/agent_configs.py @@ -350,3 +350,41 @@ chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-4o-2024-05-13"], flags=DEFAULT_RS_FLAGS, ) + + +AGENT_AZURE_4o_MINI = GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT["azure/gpt-4o-mini"], + flags=FLAGS_GPT_4o, +) +AGENT_AZURE_4o = GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT["azure/gpt-4o"], + flags=FLAGS_GPT_4o, +) +AGENT_AZURE_41 = GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT["azure/gpt-4.1"], + flags=FLAGS_GPT_4o, +) +AGENT_AZURE_41_MINI = GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT["azure/gpt-4.1-mini"], + flags=FLAGS_GPT_4o, +) + +AGENT_AZURE_4o_VISION = GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT["azure/gpt-4o"], + flags=FLAGS_GPT_4o_VISION, +) + +AGENT_AZURE_4o_MINI_VISION = GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT["azure/gpt-4o-mini"], + flags=FLAGS_GPT_4o_VISION, +) + +AGENT_AZURE_41_VISION = GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT["azure/gpt-4.1"], + flags=FLAGS_GPT_4o_VISION, +) + +AGENT_AZURE_41_MINI_VISION = GenericAgentArgs( + chat_model_args=CHAT_MODEL_ARGS_DICT["azure/gpt-4.1-mini"], + flags=FLAGS_GPT_4o_VISION, +) diff --git a/src/agentlab/llm/llm_configs.py b/src/agentlab/llm/llm_configs.py index 12f1dd27..fb461d81 100644 --- a/src/agentlab/llm/llm_configs.py +++ b/src/agentlab/llm/llm_configs.py @@ -242,4 +242,37 @@ max_new_tokens=64_000, temperature=1e-1, ), + ### Azure + "azure/gpt-4o-mini": AzureModelArgs( + model_name="gpt-4o-mini", + # deployment_name="gpt-4o-mini-2024-07-18", + max_total_tokens=128_000, + max_input_tokens=128_000, + max_new_tokens=16_384, + vision_support=True, + ), + "azure/gpt-4o": AzureModelArgs( + model_name="gpt-4o", + # deployment_name="gpt-4o-mini-2024-07-18", + max_total_tokens=128_000, + max_input_tokens=128_000, + max_new_tokens=16_384, + vision_support=True, + ), + "azure/gpt-4.1": AzureModelArgs( + model_name="gpt-4.1", + # deployment_name="gpt-4.1", + max_total_tokens=128_000, + max_input_tokens=128_000, + max_new_tokens=16_384, + vision_support=True, + ), + "azure/gpt-4.1-mini": AzureModelArgs( + model_name="gpt-4.1-mini", + # deployment_name="gpt-4.1-mini", + max_total_tokens=128_000, + max_input_tokens=128_000, + max_new_tokens=16_384, + vision_support=True, + ), } From 73745ceff4a4610c5f33952533aab0612062612d Mon Sep 17 00:00:00 2001 From: Patrice Bechard Date: Thu, 12 Jun 2025 00:30:51 -0400 Subject: [PATCH 02/24] add agentlab server and agentlab controller --- pyproject.toml | 1 + src/agentlab/analyze/agent_controller.py | 475 ++++++++++++++++++ .../analyze/run_agentlab_controller.py | 14 + src/agentlab/analyze/server.py | 426 ++++++++++++++++ 4 files changed, 916 insertions(+) create mode 100644 src/agentlab/analyze/agent_controller.py create mode 100644 src/agentlab/analyze/run_agentlab_controller.py create mode 100644 src/agentlab/analyze/server.py diff --git a/pyproject.toml b/pyproject.toml index ae49f143..d040e0b5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,3 +59,4 @@ exclude = ''' agentlab-assistant = "agentlab.ui_assistant:main" agentlab-xray = "agentlab.analyze.agent_xray:main" agentlab-controller = "agentlab.analyze.run_agentlab_controller:main" +agentlab-server = "agentlab.analyze.server:main" \ No newline at end of file diff --git a/src/agentlab/analyze/agent_controller.py b/src/agentlab/analyze/agent_controller.py new file mode 100644 index 00000000..9e9941f1 --- /dev/null +++ b/src/agentlab/analyze/agent_controller.py @@ -0,0 +1,475 @@ +import base64 +import copy +import importlib +import logging +from io import BytesIO +from pathlib import Path +import requests +import numpy as np +import PIL.Image +import streamlit as st +from agentlab.agents.generic_agent import __all__ as ALL_AGENTS +from agentlab.experiments.exp_utils import RESULTS_DIR +from bgym import DEFAULT_BENCHMARKS +from dotenv import load_dotenv +from agentlab.llm.llm_utils import Discussion +from transformers import AutoTokenizer +from datetime import datetime + +# used to display prompt. simple chat template from apache 2.0 model +tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta") + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) +load_dotenv() + +DEFAULT_AGENT = "AGENT_AZURE_4o" +DEFAULT_BENCHMARK = "workarena_l1" + +SERVER_URL = "http://127.0.0.1:8000" + + +class IgnoreMessageFilter(logging.Filter): + def filter(self, record): + return "but it does not exist!" not in record.getMessage() + + +streamlit_logger = st.watcher.local_sources_watcher._LOGGER +streamlit_logger.setLevel(logging.ERROR) + + +def get_import_path(obj): + return f"{obj.__module__}.{obj.__qualname__}" + + +def setup_sidebar(): + with st.sidebar: + st.markdown( + """ +# AgentLab Controller + +AgentLab Controller is a tool used to help control and debug an agent deployed in an environment. + +AgentLab Controller works by connecting a Streamlit UI that handles the agent to a FastAPI backend server that handles the environment. + +--- + +## Instructions + +1. ⚙️ Setup the task + - Select an agent, benchmark, task, and subtask you want to work on. + - Select "🔄" to reset the environment. This includes resetting the environment server. + - Select "▶️" to start the environment. This will start the environment by opening a browser in the background. This step might take some time + +2. 🎮 Control the environment + - Look at the goal set for the task, the thought of the model, and the action taken. + - If the action looks right, select the "▶️ Next Step" button to step the environment. + + The action will then be executed and the environment will be updated. + - If the action is wrong and you want to re-prompt, select the "🔄 Regenerate Action". + + You can also expand the "Prompt Modifier" menu to change the prompt used to generate the thoughts / actions. + - If you want to backtrack and undo the previous actions, select the "⬅️ Previous Step" button. + + Note: This is a slow process as we need to reset the environment server and perform the previous actions one by one. + +3. 🔎 Investigate the environment + - Look at the screenshot of the current environment state + - Verify that the action selected by the model matches the AxTree + - Ensure that the prompt is properly build. If there are issues with the prompt yielding the wrong action, modify them using the "Prompt Modifier" above. + """ + ) + + +def set_session_state(): + + # args used to instantiate agent / environment + if "has_submitted_configs" not in st.session_state: + st.session_state.has_submitted_configs = False + if "agent_args" not in st.session_state: + st.session_state.agent_args = None + if "benchmark" not in st.session_state: + st.session_state.benchmark = None + if "task" not in st.session_state: + st.session_state.task = None + if "subtask" not in st.session_state: + st.session_state.subtask = None + + if "agent" not in st.session_state: + st.session_state.agent = None + if "environment" not in st.session_state: + st.session_state.environment = None + if "action" not in st.session_state: + st.session_state.action = None + if "action_info" not in st.session_state: + st.session_state.action_info = None + if "actions_history" not in st.session_state: + st.session_state.actions_history = None + if "obs_history" not in st.session_state: + st.session_state.obs_history = None + + if "has_clicked_prev" not in st.session_state: + st.session_state.has_clicked_prev = False + if "has_clicked_next" not in st.session_state: + st.session_state.has_clicked_next = False + + +def select_agent(): + """Dropdown to select an agent.""" + agent_str = st.selectbox("Select Agent", ALL_AGENTS, index=ALL_AGENTS.index(DEFAULT_AGENT)) + agents_module = importlib.import_module("agentlab.agents.generic_agent") + agent = getattr(agents_module, agent_str) + return agent + + +def select_benchmark() -> str: + """Dropdown to select a benchmark.""" + all_benchmarks = list(DEFAULT_BENCHMARKS.keys()) + benchmark_str = st.selectbox("Select Benchmark", all_benchmarks, index=all_benchmarks.index(DEFAULT_BENCHMARK)) + return benchmark_str + + +def select_task(benchmark): + """Dropdown to select a task based on the benchmark.""" + all_tasks = sorted(list(set([elem.task_name for elem in benchmark.env_args_list]))) + task_str = st.selectbox("Select Task", all_tasks) + return task_str + + +def select_subtask(benchmark, task_str) -> str: + """Dropdown to select a subtask based on the task name.""" + all_subtasks = sorted([str(elem.task_seed) for elem in benchmark.env_args_list if elem.task_name == task_str]) + subtask_str = st.selectbox("Select Subtask", all_subtasks) + return subtask_str + + +def set_task_selector(): + """Create task selector form. Allows the user to select the agent, benchmark, task, and subtask to run.""" + with st.form("Task Selector"): + col1, col2, col3, col4, col5, col6 = st.columns([2, 2, 4, 2, 1, 1], vertical_alignment="bottom") + with col1: + selected_agent_args = select_agent() + with col2: + selected_benchmark_str = select_benchmark() + selected_benchmark = DEFAULT_BENCHMARKS[selected_benchmark_str]() + with col3: + selected_task_str = select_task(selected_benchmark) + with col4: + selected_subtask_str = select_subtask(selected_benchmark, selected_task_str) + with col5: + if st.form_submit_button("🔄", use_container_width=True): + clean_session() + with col6: + if st.form_submit_button("▶️", use_container_width=True): + + # saving configs related to agent and task + st.session_state.has_submitted_configs = True + st.session_state.agent_args = selected_agent_args + st.session_state.benchmark = selected_benchmark_str + st.session_state.task = selected_task_str + st.session_state.subtask = selected_subtask_str + + # Set empty state tracker + st.session_state.current_action = None + st.session_state.last_obs = None + st.session_state.actions_history = [] + st.session_state.obs_history = [] + + prepare_agent() + set_environment_info() + reset_environment() + + +def clean_session(): + logger.info("Cleaning session...") + start = datetime.now() + requests.post(f"{SERVER_URL}/unset_info") + requests.post(f"{SERVER_URL}/close") + for key in list(st.session_state.keys()): + del st.session_state[key] + end = datetime.now() + logger.info(f"Done in {end - start}") + st.rerun() + + +def prepare_agent(): + logger.info("Preparing agent...") + start = datetime.now() + st.session_state.agent_args.prepare() + st.session_state.agent = st.session_state.agent_args.make_agent() + end = datetime.now() + logger.info(f"Done in {end - start}") + + +def set_environment_info(): + logger.info("Setting environment info...") + start = datetime.now() + action_mapping_fn = get_import_path(st.session_state.agent.action_set.to_python_code) + payload = { + "benchmark_name": st.session_state.benchmark, + "task_name": st.session_state.task, + "seed": st.session_state.subtask, + "action_mapping_fn": action_mapping_fn, + "exp_dir": str(RESULTS_DIR), + } + resp = requests.post(f"{SERVER_URL}/set_info", json=payload) + if resp.status_code != 200 or resp.json().get("status") != "success": + st.error(resp.json()) + end = datetime.now() + logger.info(f"Done in {end - start}") + + +def reset_environment(): + logger.info("Restarting environment...") + start = datetime.now() + resp = requests.post(f"{SERVER_URL}/reset") + end = datetime.now() + logger.info(f"Done request in {end - start}") + start = datetime.now() + if resp.status_code != 200 or resp.json().get("status") != "success": + print(resp.status_code) + print(resp.json()["status"]) + print(resp.json()["message"]) + response_json = resp.json() + if "obs" in response_json: + if "screenshot" in response_json["obs"]: + screenshot_data = response_json["obs"]["screenshot"] + # convert base64 to numpy array + screenshot = np.frombuffer(base64.b64decode(screenshot_data["data"]), dtype=np.dtype(screenshot_data["dtype"])) + screenshot = screenshot.reshape(screenshot_data["shape"]) + response_json["obs"]["screenshot"] = screenshot + if st.session_state.agent.obs_preprocessor: + response_json["obs"] = st.session_state.agent.obs_preprocessor(response_json["obs"]) + st.session_state.last_obs = response_json["obs"] + end = datetime.now() + logger.info(f"Done postproc in {end - start}") + + +def step_environment(action): + logger.info("Stepping environment...") + start = datetime.now() + payload = {"action": action} + resp = requests.post(f"{SERVER_URL}/step", json=payload) + if resp.status_code != 200 or resp.json().get("status") != "success": + print(resp.status_code) + print(resp.json()["status"]) + print(resp.json()["message"]) + response_json = resp.json() + if "obs" in response_json: + if "screenshot" in response_json["obs"]: + screenshot_data = response_json["obs"]["screenshot"] + # convert base64 to numpy array + screenshot = np.frombuffer(base64.b64decode(screenshot_data["data"]), dtype=np.dtype(screenshot_data["dtype"])) + screenshot = screenshot.reshape(screenshot_data["shape"]) + response_json["obs"]["screenshot"] = screenshot + if st.session_state.agent.obs_preprocessor: + response_json["obs"] = st.session_state.agent.obs_preprocessor(response_json["obs"]) + st.session_state.last_obs = response_json["obs"] + st.session_state.action = None + st.session_state.action_info = None + end = datetime.now() + logger.info(f"Done in {end - start}") + + +def restore_environment(): + reset_environment() + for action in st.session_state.actions_history: + step_environment(action) + + +def get_action(): + logger.info("Getting action...") + start = datetime.now() + action, info = st.session_state.agent.get_action(copy.deepcopy(st.session_state.last_obs)) + st.session_state.action = copy.deepcopy(action) + st.session_state.action_info = copy.deepcopy(info) + end = datetime.now() + logger.info(f"Done in {end - start}") + + +def set_agent_state_box(): + # set agent state and goal box + with st.container(): + col1, col2, col3 = st.columns([1, 1, 1]) + with col1: + with st.container(border=True, height=250): + st.markdown("**Goal**") + st.code(st.session_state.agent.obs_history[-1]["goal"], wrap_lines=True, language=None, height=175) + with col2: + with st.container(border=True, height=250): + st.markdown("**Think**") + st.code(st.session_state.action_info.think, wrap_lines=True, language=None, height=175) + with col3: + with st.container(border=True, height=250): + st.markdown("**Action**") + st.code(st.session_state.action, wrap_lines=True, language="python", height=175) + + +def set_prompt_modifier(): + with st.expander("**Prompt Modifier**", expanded=False): + # st.write(st.session_state.agent.flags) + st.markdown("**Observation Flags**") + col1, col2, col3, col4, col5, col6 = st.columns([1, 1, 1, 1, 1, 1]) + with col1: + st.session_state.agent.flags.obs.use_html = st.checkbox("use_html", value=st.session_state.agent.flags.obs.use_html) + st.session_state.agent.flags.obs.use_action_history = st.checkbox( + "use_action_history", value=st.session_state.agent.flags.obs.use_action_history + ) + with col2: + st.session_state.agent.flags.obs.use_ax_tree = st.checkbox("use_ax_tree", value=st.session_state.agent.flags.obs.use_ax_tree) + st.session_state.agent.flags.obs.use_think_history = st.checkbox( + "use_think_history", value=st.session_state.agent.flags.obs.use_think_history + ) + with col3: + st.session_state.agent.flags.obs.use_focused_element = st.checkbox( + "use_focused_element", value=st.session_state.agent.flags.obs.use_focused_element + ) + st.session_state.agent.flags.obs.use_diff = st.checkbox("use_diff", value=st.session_state.agent.flags.obs.use_diff) + with col4: + st.session_state.agent.flags.obs.use_error_logs = st.checkbox( + "use_error_logs", value=st.session_state.agent.flags.obs.use_error_logs + ) + st.session_state.agent.flags.obs.use_screenshot = st.checkbox( + "use_screenshot", value=st.session_state.agent.flags.obs.use_screenshot + ) + with col5: + st.session_state.agent.flags.obs.use_history = st.checkbox("use_history", value=st.session_state.agent.flags.obs.use_history) + st.session_state.agent.flags.obs.use_som = st.checkbox("use_som", value=st.session_state.agent.flags.obs.use_som) + with col6: + st.session_state.agent.flags.obs.use_past_error_logs = st.checkbox( + "use_past_error_logs", value=st.session_state.agent.flags.obs.use_past_error_logs + ) + st.session_state.agent.flags.obs.use_tabs = st.checkbox("use_tabs", value=st.session_state.agent.flags.obs.use_tabs) + st.markdown("**Other Flags**") + col1, col2, col3, col4, col5, col6 = st.columns([1, 1, 1, 1, 1, 1]) + with col1: + st.session_state.agent.flags.use_plan = st.checkbox("use_plan", value=st.session_state.agent.flags.use_plan) + st.session_state.agent.flags.use_hints = st.checkbox("use_hints", value=st.session_state.agent.flags.use_hints) + with col2: + st.session_state.agent.flags.use_criticise = st.checkbox("use_criticise", value=st.session_state.agent.flags.use_criticise) + st.session_state.agent.flags.be_cautious = st.checkbox("be_cautious", value=st.session_state.agent.flags.be_cautious) + with col3: + st.session_state.agent.flags.use_thinking = st.checkbox("use_thinking", value=st.session_state.agent.flags.use_thinking) + st.session_state.agent.flags.enable_chat = st.checkbox("enable_chat", value=st.session_state.agent.flags.enable_chat) + with col4: + st.session_state.agent.flags.use_memory = st.checkbox("use_memory", value=st.session_state.agent.flags.use_memory) + with col5: + st.session_state.agent.flags.use_abstract_example = st.checkbox( + "use_abstract_example", value=st.session_state.agent.flags.use_abstract_example + ) + with col6: + st.session_state.agent.flags.use_concrete_example = st.checkbox( + "use_concrete_example", value=st.session_state.agent.flags.use_concrete_example + ) + extra_instructions = st.text_area("extra_instructions", value=st.session_state.agent.flags.extra_instructions) + if extra_instructions == "": + extra_instructions = None + st.session_state.agent.flags.extra_instructions = extra_instructions + + +def undo_last_agent_step(): + st.session_state.agent.obs_history.pop() # remove last observation + st.session_state.agent.actions.pop() # remove last action + st.session_state.agent.thoughts.pop() # remove last thought + st.session_state.agent.memories.pop() # remove last memory + + +def set_controller(): + set_agent_state_box() + set_prompt_modifier() + col_prev, col_redo, col_next = st.columns([1, 1, 1]) + with col_prev: + prev_disabled = len(st.session_state.actions_history) == 0 + if st.button("⬅️ Previous Step", disabled=prev_disabled, use_container_width=True): + if not prev_disabled: + st.session_state.actions_history.pop() + st.session_state.action = None if len(st.session_state.actions_history) == 0 else st.session_state.actions_history[-1] + undo_last_agent_step() + undo_last_agent_step() + restore_environment() + st.rerun() + with col_redo: + if st.button("🔄 Regenerate Action", use_container_width=True): + undo_last_agent_step() + get_action() + st.rerun() + with col_next: + if st.button("➡️ Next Step", use_container_width=True): + st.session_state.actions_history.append(st.session_state.action) + step_environment(st.session_state.action) + st.session_state.action = None + st.rerun() + + +def set_screenshot_tab(): + if isinstance(st.session_state.last_obs, dict): + if st.session_state.last_obs.get("screenshot", None) is not None: + img_arr = st.session_state.last_obs["screenshot"] + if isinstance(img_arr, list): + img_arr = np.array(img_arr) + if isinstance(img_arr, np.ndarray): + im = PIL.Image.fromarray(img_arr) + buffered = BytesIO() + im.save(buffered, format="PNG") + img_b64 = base64.b64encode(buffered.getvalue()).decode() + st.markdown( + f'
', + unsafe_allow_html=True, + ) + + +def set_axtree_tab(): + if isinstance(st.session_state.last_obs, dict): + if st.session_state.last_obs.get("axtree_txt", None) is not None: + st.code(st.session_state.last_obs["axtree_txt"], language=None) + + +def set_prompt_tab(): + if st.session_state.action_info is not None and isinstance(st.session_state.action_info.chat_messages, Discussion): + chat_messages = st.session_state.action_info.chat_messages.messages + new_chat_messages = [] + for message in chat_messages: + if isinstance(message["content"], list): + # concatenate all text elements + new_chat_messages.append( + {"role": message["role"], "content": "\n\n".join([elem["text"] for elem in message["content"] if elem["type"] == "text"])} + ) + else: + new_chat_messages.append(message) + st.code(tokenizer.apply_chat_template(new_chat_messages, add_special_tokens=True, tokenize=False), wrap_lines=True, language="markdown") + + +def set_info_tabs(): + # Display only if everything is now ready + tab1, tab2, tab3 = st.tabs(["Screenshot", "AxTree", "Prompt"]) + + with tab1: + set_screenshot_tab() + with tab2: + set_axtree_tab() + with tab3: + set_prompt_tab() + + +def run_streamlit(): + + # config page + st.set_page_config(page_title="AgentLab Controller", page_icon="🎮", layout="wide", initial_sidebar_state="collapsed") + st.markdown('

🎮 AgentLab Controller 🎮

', unsafe_allow_html=True) + + setup_sidebar() + + set_session_state() + set_task_selector() + + if st.session_state.agent is not None: + if st.session_state.action is None: + get_action() + with st.container(border=True): + set_controller() + set_info_tabs() + + +def main(): + run_streamlit() + + +if __name__ == "__main__": + main() diff --git a/src/agentlab/analyze/run_agentlab_controller.py b/src/agentlab/analyze/run_agentlab_controller.py new file mode 100644 index 00000000..5b472789 --- /dev/null +++ b/src/agentlab/analyze/run_agentlab_controller.py @@ -0,0 +1,14 @@ +from streamlit.web import cli + +from pathlib import Path + +CURR_DIR = Path(__file__).parent +agent_controller_path = CURR_DIR / "agent_controller.py" + + +def main(): + cli.main_run([str(agent_controller_path), "--server.port", "8501"]) + + +if __name__ == "__main__": + main() diff --git a/src/agentlab/analyze/server.py b/src/agentlab/analyze/server.py new file mode 100644 index 00000000..66aa28ae --- /dev/null +++ b/src/agentlab/analyze/server.py @@ -0,0 +1,426 @@ +# server.py +import base64 +import copy +import importlib +import logging +import time +from typing import Any, Dict, Optional + +import dotenv +import numpy as np +import uvicorn + +# Import your BrowserEnv and any task setup you need +from bgym import DEFAULT_BENCHMARKS +from browsergym.core.env import BrowserEnv +from browsergym.core.task import AbstractBrowserTask +from fastapi import FastAPI, Request +from pydantic import BaseModel + +dotenv.load_dotenv() + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + +app = FastAPI() + + +# Utils to import the action mapping fn +def import_from_path(path): + """ + Import and instantiate a class, then return its 'to_python_code' method. + For example, given 'browsergym.core.action.highlevel.HighLevelActionSet.to_python_code', + this will instantiate HighLevelActionSet and return its to_python_code method. + """ + import importlib + + parts = path.split(".") + # Find the module (the longest prefix that can be imported) + for i in range(len(parts), 0, -1): + module_name = ".".join(parts[:i]) + try: + module = importlib.import_module(module_name) + break + except ModuleNotFoundError: + continue + else: + raise ModuleNotFoundError(f"Could not import module from path: {path}") + + obj = module + for attr in parts[i:]: + obj = getattr(obj, attr) + + # If the final object is a method, and its __qualname__ contains a class, instantiate the class + if callable(obj) and hasattr(obj, "__qualname__") and "." in obj.__qualname__: + class_name = obj.__qualname__.split(".")[0] + cls = getattr(module, class_name) + instance = cls() + method = getattr(instance, obj.__name__) + return method + + return obj + + +## Utils to convert to safe JSON response +def make_json_safe(obj): + if isinstance(obj, np.ndarray): + # convert to base64 + return {"data": base64.b64encode(obj.tobytes()).decode("utf-8"), "shape": obj.shape, "dtype": str(obj.dtype)} + elif isinstance(obj, dict): + return {k: make_json_safe(v) for k, v in obj.items()} + elif isinstance(obj, (list, tuple)): + return [make_json_safe(v) for v in obj] + elif hasattr(obj, "__dict__"): + return make_json_safe(vars(obj)) + else: + return obj + + +# --- Models for requests --- +class SetInfoRequest(BaseModel): + benchmark_name: str + task_name: str + seed: int + action_mapping_fn: str + exp_dir: str + + +class StepRequest(BaseModel): + action: str + + +# --- Persistent Environment State --- +class EnvWrapper: + def __init__(self): + + # env info + self.benchmark_name = None + self.task_name = None + self.seed = None + self.action_mapping_fn = None + self.exp_dir = None + self.info_set = False + + # env state + self.env = None + self.last_obs = None + self.last_info = None + + def set_info( + self, + benchmark_name: str, + task_name: str, + seed: int, + action_mapping_fn: str, + exp_dir: str, + ): + """Set the environment info. + + :param benchmark_name: Name of the benchmark + :type benchmark_name: str + :param task_name: Name of the task + :type task_name: str + :param seed: Seed of the task. + :type seed: int + :param action_mapping_fn: Action mapping function + :type action_mapping_fn: str + :param exp_dir: Directory for experiment + :type exp_dir: str + :return: Dictionary with status + :rtype: dict + """ + if self.info_set: + return make_json_safe( + { + "status": "error", + "message": "Environment info already set. Please unset the environment info first.", + } + ) + if self.env is not None: + return make_json_safe( + { + "status": "error", + "message": "Environment already created. Close the current environment proceeding.", + } + ) + self.benchmark_name = benchmark_name + self.task_name = task_name + self.seed = seed + self.action_mapping_fn = action_mapping_fn + self.exp_dir = exp_dir + self.info_set = True + + return make_json_safe( + { + "status": "success", + "message": "Environment info set successfully.", + } + ) + + def get_info(self) -> dict: + """Get the environment info + + :return: Dictionary with info + :rtype: dict + """ + if not self.info_set: + return make_json_safe( + { + "status": "error", + "message": "Environment info not set. Please set the environment info first.", + } + ) + return make_json_safe( + { + "status": "success", + "message": "Environment info retrieved successfully.", + "benchmark_name": self.benchmark_name, + "task_name": self.task_name, + "seed": self.seed, + "action_mapping_fn": self.action_mapping_fn, + "exp_dir": self.exp_dir, + } + ) + + def unset_info(self) -> dict: + """Unset the environment info + + :return: Dictionary with status + :rtype: dict + """ + if not self.info_set: + return make_json_safe( + { + "status": "error", + "message": "Environment info not set. Please set the environment info first.", + } + ) + self.info_set = False + self.benchmark_name = None + self.task_name = None + self.seed = None + self.action_mapping_fn = None + self.exp_dir = None + return make_json_safe( + { + "status": "success", + "message": "Environment info unset successfully.", + } + ) + + def status(self) -> dict: + """Get the environment status + + :return: Dictionary with status + :rtype: dict + """ + return make_json_safe( + { + "status": "success", + "message": "Environment status retrieved successfully.", + "info_set": self.info_set, + "env_created": self.env is not None, + } + ) + + def reset(self) -> dict: + """Reset the environment + + :return: Dictionary with obs and info + :rtype: dict + """ + start = time.time() + if not self.info_set: + return make_json_safe( + { + "status": "error", + "message": "Environment info not set. Please set the environment info first.", + } + ) + if self.env is not None: + # close the current environment first + self.env.close() + self.env = None + + # then create the new environment + benchmark = DEFAULT_BENCHMARKS[self.benchmark_name]() + benchmark.env_args_list = [ + elem for elem in benchmark.env_args_list if elem.task_name == self.task_name and str(elem.task_seed) == str(self.seed) + ] + benchmark.prepare_backends() + + env_args = benchmark.env_args_list[0] + # env_args.headless = False + + self.action_mapping = import_from_path(self.action_mapping_fn) + end = time.time() + logger.info(f"init reset done in {end - start}") + start = time.time() + self.env = env_args.make_env(self.action_mapping, self.exp_dir) + end = time.time() + logger.info(f"make_env done in {end - start}") + start = time.time() + # finally, reset the environment + obs, info = self.env.reset(seed=self.seed) + self.last_obs = copy.deepcopy(obs) + self.last_info = copy.deepcopy(info) + end = time.time() + logger.info(f"env reset done in {end - start}") + start = time.time() + # out = make_json_safe( + out = make_json_safe( + { + "status": "success", + "message": "Environment reset successfully", + "obs": self.last_obs, + "info": self.last_info, + } + ) + end = time.time() + logger.info(f"payload cleaned in {end - start}") + # log payload size + from pympler import asizeof + + logger.info(f"Payload size: {asizeof.asizeof(out)} bytes") + # print(out) + # return {"status": "success", "message": "Environment reset successfully"} + return out + + def step(self, action: str) -> dict: + """Step the environment + + :param action: Action to take + :type action: str + :return: Dictionary with obs, reward, terminated, truncated and info + :rtype: dict + """ + if self.env is None: + return make_json_safe( + { + "status": "error", + "message": "Environment not created. Please create an environment first.", + } + ) + start = time.time() + obs, reward, terminated, truncated, info = self.env.step(action) + end = time.time() + logger.info(f"env step done in {end - start}") + start = time.time() + self.last_obs = copy.deepcopy(obs) + self.last_info = copy.deepcopy(info) + out = make_json_safe( + { + "status": "success", + "message": "Environment stepped successfully.", + "obs": obs, + "reward": reward, + "terminated": terminated, + "truncated": truncated, + "info": info, + } + ) + end = time.time() + logger.info(f"obs copied in {end - start}") + return out + + def get_obs(self) -> dict: + """Get the last observation + + :return: Dictionary with obs and info + :rtype: dict + """ + if self.env is None: + return make_json_safe( + { + "status": "error", + "message": "Environment not created. Please create an environment first.", + } + ) + return make_json_safe( + { + "status": "success", + "message": "Observation retrieved successfully.", + "obs": self.last_obs, + "info": self.last_info, + } + ) + + def close(self) -> dict: + """Close the environment + + :return: Dictionary with status + :rtype: dict + """ + if self.env is None: + return make_json_safe( + { + "status": "error", + "message": "Environment not created. Please create an environment first.", + } + ) + self.env.close() + self.env = None + return make_json_safe( + { + "status": "success", + "message": "Environment closed successfully.", + } + ) + + +env = EnvWrapper() + + +# --- FastAPI endpoints --- +@app.post("/set_info") +def set_info(req: SetInfoRequest): + return env.set_info( + benchmark_name=req.benchmark_name, + task_name=req.task_name, + seed=req.seed, + action_mapping_fn=req.action_mapping_fn, + exp_dir=req.exp_dir, + ) + + +@app.get("/get_info") +def get_info(): + return env.get_info() + + +@app.post("/unset_info") +def unset_info(): + return env.unset_info() + + +@app.get("/status") +def status(): + return env.status() + + +@app.post("/reset") +def reset(): + return env.reset() + + +@app.post("/step") +def step(req: StepRequest): + return env.step(action=req.action) + + +@app.get("/get_obs") +def get_obs(): + return env.get_obs() + + +@app.post("/close") +def close(): + return env.close() + + +def main(): + uvicorn.run("agentlab.analyze.server:app", host="127.0.0.1", port=8000, reload=True) + + +if __name__ == "__main__": + main() From a89a37d5de01df32ad39c253d7dae3f1f511f283 Mon Sep 17 00:00:00 2001 From: Patrice Bechard Date: Thu, 12 Jun 2025 10:58:08 -0400 Subject: [PATCH 03/24] add updating action and thought, add reload task instead of reset --- src/agentlab/analyze/agent_controller.py | 65 +++++++++++- src/agentlab/analyze/server.py | 129 ++++++++++++++++++----- 2 files changed, 166 insertions(+), 28 deletions(-) diff --git a/src/agentlab/analyze/agent_controller.py b/src/agentlab/analyze/agent_controller.py index 9e9941f1..4a1a611f 100644 --- a/src/agentlab/analyze/agent_controller.py +++ b/src/agentlab/analyze/agent_controller.py @@ -174,6 +174,7 @@ def set_task_selector(): prepare_agent() set_environment_info() + prepare_benchmark() reset_environment() @@ -216,6 +217,16 @@ def set_environment_info(): logger.info(f"Done in {end - start}") +def prepare_benchmark(): + logger.info("Preparing benchmark...") + start = datetime.now() + resp = requests.post(f"{SERVER_URL}/prepare_benchmark") + if resp.status_code != 200 or resp.json().get("status") != "success": + st.error(resp.json()) + end = datetime.now() + logger.info(f"Done in {end - start}") + + def reset_environment(): logger.info("Restarting environment...") start = datetime.now() @@ -242,6 +253,29 @@ def reset_environment(): logger.info(f"Done postproc in {end - start}") +def reload_task(): + logger.info("Reloading task...") + start = datetime.now() + resp = requests.post(f"{SERVER_URL}/reload_task") + if resp.status_code != 200 or resp.json().get("status") != "success": + print(resp.status_code) + print(resp.json()["status"]) + print(resp.json()["message"]) + response_json = resp.json() + if "obs" in response_json: + if "screenshot" in response_json["obs"]: + screenshot_data = response_json["obs"]["screenshot"] + # convert base64 to numpy array + screenshot = np.frombuffer(base64.b64decode(screenshot_data["data"]), dtype=np.dtype(screenshot_data["dtype"])) + screenshot = screenshot.reshape(screenshot_data["shape"]) + response_json["obs"]["screenshot"] = screenshot + if st.session_state.agent.obs_preprocessor: + response_json["obs"] = st.session_state.agent.obs_preprocessor(response_json["obs"]) + st.session_state.last_obs = response_json["obs"] + end = datetime.now() + logger.info(f"Done in {end - start}") + + def step_environment(action): logger.info("Stepping environment...") start = datetime.now() @@ -269,7 +303,7 @@ def step_environment(action): def restore_environment(): - reset_environment() + reload_task() for action in st.session_state.actions_history: step_environment(action) @@ -285,21 +319,46 @@ def get_action(): def set_agent_state_box(): + + # Custom CSS to set textarea style same as code block + st.markdown( + """ + + """, + unsafe_allow_html=True, + ) + # set agent state and goal box with st.container(): col1, col2, col3 = st.columns([1, 1, 1]) with col1: with st.container(border=True, height=250): st.markdown("**Goal**") + # st.text_area("", st.session_state.agent.obs_history[-1]["goal"], height=175, disabled=True, label_visibility="collapsed") st.code(st.session_state.agent.obs_history[-1]["goal"], wrap_lines=True, language=None, height=175) with col2: with st.container(border=True, height=250): st.markdown("**Think**") - st.code(st.session_state.action_info.think, wrap_lines=True, language=None, height=175) + st.session_state.action_info.think = st.text_area( + "Think", st.session_state.action_info.think, height=172, label_visibility="collapsed" + ) with col3: with st.container(border=True, height=250): st.markdown("**Action**") - st.code(st.session_state.action, wrap_lines=True, language="python", height=175) + st.session_state.action = st.text_area("Action", st.session_state.action, height=172, label_visibility="collapsed") + # st.code(st.session_state.action, wrap_lines=True, language="python", height=175) def set_prompt_modifier(): diff --git a/src/agentlab/analyze/server.py b/src/agentlab/analyze/server.py index 66aa28ae..0ea04a2b 100644 --- a/src/agentlab/analyze/server.py +++ b/src/agentlab/analyze/server.py @@ -106,6 +106,10 @@ def __init__(self): self.last_obs = None self.last_info = None + # used to reload task + self.start_info = None + self.start_url = None + def set_info( self, benchmark_name: str, @@ -223,12 +227,7 @@ def status(self) -> dict: } ) - def reset(self) -> dict: - """Reset the environment - - :return: Dictionary with obs and info - :rtype: dict - """ + def prepare_benchmark(self) -> dict: start = time.time() if not self.info_set: return make_json_safe( @@ -237,38 +236,117 @@ def reset(self) -> dict: "message": "Environment info not set. Please set the environment info first.", } ) + if self.env is not None: # close the current environment first self.env.close() self.env = None - # then create the new environment benchmark = DEFAULT_BENCHMARKS[self.benchmark_name]() benchmark.env_args_list = [ elem for elem in benchmark.env_args_list if elem.task_name == self.task_name and str(elem.task_seed) == str(self.seed) ] + start = time.time() benchmark.prepare_backends() + end = time.time() + logger.info(f"prepare_backends done in {end - start}") env_args = benchmark.env_args_list[0] - # env_args.headless = False - self.action_mapping = import_from_path(self.action_mapping_fn) - end = time.time() - logger.info(f"init reset done in {end - start}") + + # create environment start = time.time() self.env = env_args.make_env(self.action_mapping, self.exp_dir) + print(self.env) end = time.time() logger.info(f"make_env done in {end - start}") + return make_json_safe( + { + "status": "success", + "message": "Environment prepared successfully.", + } + ) + + def reload_task(self) -> dict: + """Reload the task + + :return: Dictionary with status + :rtype: dict + """ + start = time.time() + if not self.info_set: + return make_json_safe( + { + "status": "error", + "message": "Environment info not set. Please set the environment info first.", + } + ) + elif not self.env: + return make_json_safe( + { + "status": "error", + "message": "Environment not created. Please create an environment first.", + } + ) + + tmp_start = time.time() + self.env.unwrapped.page.goto(self.start_url, wait_until="load") + tmp_end = time.time() + logger.info(f"goto done in {tmp_end - tmp_start}") + tmp_start = time.time() + self.env.unwrapped.page.evaluate("window.localStorage.clear(); window.sessionStorage.clear();") + + obs = self.env.unwrapped._get_obs() + tmp_end = time.time() + logger.info(f"clear storage done in {tmp_end - tmp_start}") + + end = time.time() + logger.info(f"reload_task done in {end - start}") + + self.last_obs = copy.deepcopy(obs) + self.last_info = copy.deepcopy(self.start_info) + return make_json_safe( + { + "status": "success", + "message": "Task reloaded successfully.", + "obs": self.last_obs, + "info": self.last_info, + } + ) + + def reset(self) -> dict: + """Reset the environment + + :return: Dictionary with obs and info + :rtype: dict + """ start = time.time() + if not self.info_set: + return make_json_safe( + { + "status": "error", + "message": "Environment info not set. Please set the environment info first.", + } + ) + elif not self.env: + return make_json_safe( + { + "status": "error", + "message": "Environment not created. Please create an environment first.", + } + ) + # finally, reset the environment + start = time.time() obs, info = self.env.reset(seed=self.seed) - self.last_obs = copy.deepcopy(obs) - self.last_info = copy.deepcopy(info) end = time.time() logger.info(f"env reset done in {end - start}") - start = time.time() - # out = make_json_safe( - out = make_json_safe( + + self.last_obs = copy.deepcopy(obs) + self.last_info = copy.deepcopy(info) + self.start_info = copy.deepcopy(info) + self.start_url = copy.deepcopy(self.env.unwrapped.page.url) + return make_json_safe( { "status": "success", "message": "Environment reset successfully", @@ -276,15 +354,6 @@ def reset(self) -> dict: "info": self.last_info, } ) - end = time.time() - logger.info(f"payload cleaned in {end - start}") - # log payload size - from pympler import asizeof - - logger.info(f"Payload size: {asizeof.asizeof(out)} bytes") - # print(out) - # return {"status": "success", "message": "Environment reset successfully"} - return out def step(self, action: str) -> dict: """Step the environment @@ -398,6 +467,16 @@ def status(): return env.status() +@app.post("/prepare_benchmark") +def prepare_benchmark(): + return env.prepare_benchmark() + + +@app.post("/reload_task") +def reload_task(): + return env.reload_task() + + @app.post("/reset") def reset(): return env.reset() From 033e21d1f93c6f40a831bcfe59f2ff6037d28aff Mon Sep 17 00:00:00 2001 From: Patrice Bechard Date: Thu, 12 Jun 2025 12:30:41 -0400 Subject: [PATCH 04/24] remove deployment name from azure model args --- src/agentlab/llm/llm_configs.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/agentlab/llm/llm_configs.py b/src/agentlab/llm/llm_configs.py index fb461d81..52ecbbe3 100644 --- a/src/agentlab/llm/llm_configs.py +++ b/src/agentlab/llm/llm_configs.py @@ -245,7 +245,6 @@ ### Azure "azure/gpt-4o-mini": AzureModelArgs( model_name="gpt-4o-mini", - # deployment_name="gpt-4o-mini-2024-07-18", max_total_tokens=128_000, max_input_tokens=128_000, max_new_tokens=16_384, @@ -253,7 +252,6 @@ ), "azure/gpt-4o": AzureModelArgs( model_name="gpt-4o", - # deployment_name="gpt-4o-mini-2024-07-18", max_total_tokens=128_000, max_input_tokens=128_000, max_new_tokens=16_384, @@ -261,7 +259,6 @@ ), "azure/gpt-4.1": AzureModelArgs( model_name="gpt-4.1", - # deployment_name="gpt-4.1", max_total_tokens=128_000, max_input_tokens=128_000, max_new_tokens=16_384, @@ -269,7 +266,6 @@ ), "azure/gpt-4.1-mini": AzureModelArgs( model_name="gpt-4.1-mini", - # deployment_name="gpt-4.1-mini", max_total_tokens=128_000, max_input_tokens=128_000, max_new_tokens=16_384, From d65ad94c250a72a92d4087061fed20a3137fea94 Mon Sep 17 00:00:00 2001 From: Patrice Bechard Date: Thu, 12 Jun 2025 12:56:31 -0400 Subject: [PATCH 05/24] add streamlit requirement, update readme --- README.md | 42 ++++++++++++++++++++++++++++++++++++++++++ requirements.txt | 1 + 2 files changed, 43 insertions(+) diff --git a/README.md b/README.md index b5807314..9b8d3233 100644 --- a/README.md +++ b/README.md @@ -226,6 +226,48 @@ image to select a step and observe the action taken by the agent. **⚠️ Note**: Gradio is still developing, and unexpected behavior has been frequently noticed. Version 5.5 seems to work properly so far. If you're not sure that the proper information is displaying, refresh the page and select your experiment again. +### AgentLab Server and AgentLab Controller + +The AgentLab Server and Controller are two components that work together to control and debug an agent deployed in an environment. + +#### Prerequisites + +First, set a `.env` file at the root of the repo with the following content: + +```bash +# LLM Creds (Azure as an example) +AZURE_OPENAI_ENDPOINT= +AZURE_OPENAI_API_KEY= +AZURE_OPENAI_API_VERSION= + +# ServiceNow dev instance creds +SNOW_INSTANCE_URL=https://.service-now.com/ +SNOW_INSTANCE_UNAME="admin" +SNOW_INSTANCE_PWD= + +# MiniWob +MINIWOB_URL="file:///path/to/BrowserGym/miniwob-plusplus/miniwob/html/miniwob/" +``` + +#### Launch the server + +The AgentLab Server is responsible for hosting and enabling interaction with the environment. It is a lightweight FastAPI server that handles the BrowserGym environment and provides a REST API for the controller. + +To launch the server, open a terminal and run (you will need to keep this terminal open for the next step): + +```bash +agentlab-server +``` + +#### Launch the controller + +The AgentLab Controller is a streamlit app responsible for controlling the agent and how it interacts with the environment hosted on the server. + +To launch the controller, open a new terminal and run: + +```bash +agentlab-controller +``` ## 🏆 Leaderboard diff --git a/requirements.txt b/requirements.txt index 6322ffd3..d29159ef 100644 --- a/requirements.txt +++ b/requirements.txt @@ -27,3 +27,4 @@ ray[default] python-slugify pillow gymnasium>=0.27 +streamlit From abc5af185c44b6bdb13de31f16a46c51610b0342 Mon Sep 17 00:00:00 2001 From: Patrice Bechard Date: Thu, 12 Jun 2025 12:56:44 -0400 Subject: [PATCH 06/24] clean up controller and server --- src/agentlab/analyze/agent_controller.py | 1 - src/agentlab/analyze/server.py | 71 ++++++++---------------- 2 files changed, 23 insertions(+), 49 deletions(-) diff --git a/src/agentlab/analyze/agent_controller.py b/src/agentlab/analyze/agent_controller.py index 4a1a611f..1e0ed512 100644 --- a/src/agentlab/analyze/agent_controller.py +++ b/src/agentlab/analyze/agent_controller.py @@ -3,7 +3,6 @@ import importlib import logging from io import BytesIO -from pathlib import Path import requests import numpy as np import PIL.Image diff --git a/src/agentlab/analyze/server.py b/src/agentlab/analyze/server.py index 0ea04a2b..53213cb0 100644 --- a/src/agentlab/analyze/server.py +++ b/src/agentlab/analyze/server.py @@ -2,9 +2,6 @@ import base64 import copy import importlib -import logging -import time -from typing import Any, Dict, Optional import dotenv import numpy as np @@ -12,27 +9,20 @@ # Import your BrowserEnv and any task setup you need from bgym import DEFAULT_BENCHMARKS -from browsergym.core.env import BrowserEnv -from browsergym.core.task import AbstractBrowserTask -from fastapi import FastAPI, Request +from fastapi import FastAPI from pydantic import BaseModel dotenv.load_dotenv() -logger = logging.getLogger(__name__) -logger.setLevel(logging.INFO) - app = FastAPI() -# Utils to import the action mapping fn def import_from_path(path): """ - Import and instantiate a class, then return its 'to_python_code' method. - For example, given 'browsergym.core.action.highlevel.HighLevelActionSet.to_python_code', - this will instantiate HighLevelActionSet and return its to_python_code method. + Util function to import and instantiate a class, then return a specific method. + For example, given `browsergym.core.action.highlevel.HighLevelActionSet.to_python_code`, + this will instantiate `HighLevelActionSet` and return its `to_python_code` method. """ - import importlib parts = path.split(".") # Find the module (the longest prefix that can be imported) @@ -61,8 +51,11 @@ def import_from_path(path): return obj -## Utils to convert to safe JSON response def make_json_safe(obj): + """ + Util function to convert numpy arrays and other non-JSON-serializable objects to JSON-serializable objects. + Specifically, we convert numpy arrays to base64 encoded strings so that payloads are of reasonable size. + """ if isinstance(obj, np.ndarray): # convert to base64 return {"data": base64.b64encode(obj.tobytes()).decode("utf-8"), "shape": obj.shape, "dtype": str(obj.dtype)} @@ -228,7 +221,12 @@ def status(self) -> dict: ) def prepare_benchmark(self) -> dict: - start = time.time() + """ + Prepare the benchmark environment. + + :return: Dictionary with status + :rtype: dict + """ if not self.info_set: return make_json_safe( { @@ -241,25 +239,19 @@ def prepare_benchmark(self) -> dict: # close the current environment first self.env.close() self.env = None - # then create the new environment + + # prepare backends benchmark = DEFAULT_BENCHMARKS[self.benchmark_name]() benchmark.env_args_list = [ elem for elem in benchmark.env_args_list if elem.task_name == self.task_name and str(elem.task_seed) == str(self.seed) ] - start = time.time() benchmark.prepare_backends() - end = time.time() - logger.info(f"prepare_backends done in {end - start}") env_args = benchmark.env_args_list[0] self.action_mapping = import_from_path(self.action_mapping_fn) # create environment - start = time.time() self.env = env_args.make_env(self.action_mapping, self.exp_dir) - print(self.env) - end = time.time() - logger.info(f"make_env done in {end - start}") return make_json_safe( { "status": "success", @@ -273,7 +265,6 @@ def reload_task(self) -> dict: :return: Dictionary with status :rtype: dict """ - start = time.time() if not self.info_set: return make_json_safe( { @@ -289,19 +280,12 @@ def reload_task(self) -> dict: } ) - tmp_start = time.time() + # instead of resetting the whole environment, we go back to the original webpage and clear localStorage and sessionStorage + # NOTE: this is not guaranteed to result in the exact same state, but we find that it works most of the time, is much + # faster than resetting the whole environment, and ensures the seed of the environment remains the same self.env.unwrapped.page.goto(self.start_url, wait_until="load") - tmp_end = time.time() - logger.info(f"goto done in {tmp_end - tmp_start}") - tmp_start = time.time() self.env.unwrapped.page.evaluate("window.localStorage.clear(); window.sessionStorage.clear();") - obs = self.env.unwrapped._get_obs() - tmp_end = time.time() - logger.info(f"clear storage done in {tmp_end - tmp_start}") - - end = time.time() - logger.info(f"reload_task done in {end - start}") self.last_obs = copy.deepcopy(obs) self.last_info = copy.deepcopy(self.start_info) @@ -320,7 +304,6 @@ def reset(self) -> dict: :return: Dictionary with obs and info :rtype: dict """ - start = time.time() if not self.info_set: return make_json_safe( { @@ -336,11 +319,8 @@ def reset(self) -> dict: } ) - # finally, reset the environment - start = time.time() + # reset the environment obs, info = self.env.reset(seed=self.seed) - end = time.time() - logger.info(f"env reset done in {end - start}") self.last_obs = copy.deepcopy(obs) self.last_info = copy.deepcopy(info) @@ -370,14 +350,12 @@ def step(self, action: str) -> dict: "message": "Environment not created. Please create an environment first.", } ) - start = time.time() + # step the environment obs, reward, terminated, truncated, info = self.env.step(action) - end = time.time() - logger.info(f"env step done in {end - start}") - start = time.time() + self.last_obs = copy.deepcopy(obs) self.last_info = copy.deepcopy(info) - out = make_json_safe( + return make_json_safe( { "status": "success", "message": "Environment stepped successfully.", @@ -388,9 +366,6 @@ def step(self, action: str) -> dict: "info": info, } ) - end = time.time() - logger.info(f"obs copied in {end - start}") - return out def get_obs(self) -> dict: """Get the last observation From 9665579bc61f6fa8661ad3213e2e9f2c8e860e98 Mon Sep 17 00:00:00 2001 From: Patrice Bechard Date: Thu, 12 Jun 2025 13:00:32 -0400 Subject: [PATCH 07/24] clean up agent controller --- src/agentlab/analyze/agent_controller.py | 57 ++++++++++-------------- 1 file changed, 24 insertions(+), 33 deletions(-) diff --git a/src/agentlab/analyze/agent_controller.py b/src/agentlab/analyze/agent_controller.py index 1e0ed512..511ba0b3 100644 --- a/src/agentlab/analyze/agent_controller.py +++ b/src/agentlab/analyze/agent_controller.py @@ -41,6 +41,17 @@ def get_import_path(obj): return f"{obj.__module__}.{obj.__qualname__}" +def deserialize_response(response_json): + if "obs" in response_json: + if "screenshot" in response_json["obs"]: + screenshot_data = response_json["obs"]["screenshot"] + # convert base64 to numpy array + screenshot = np.frombuffer(base64.b64decode(screenshot_data["data"]), dtype=np.dtype(screenshot_data["dtype"])) + screenshot = screenshot.reshape(screenshot_data["shape"]) + response_json["obs"]["screenshot"] = screenshot + return response_json + + def setup_sidebar(): with st.sidebar: st.markdown( @@ -234,17 +245,11 @@ def reset_environment(): logger.info(f"Done request in {end - start}") start = datetime.now() if resp.status_code != 200 or resp.json().get("status") != "success": - print(resp.status_code) - print(resp.json()["status"]) - print(resp.json()["message"]) + logger.error(resp.status_code) + logger.error(resp.json()["status"]) + logger.error(resp.json()["message"]) response_json = resp.json() - if "obs" in response_json: - if "screenshot" in response_json["obs"]: - screenshot_data = response_json["obs"]["screenshot"] - # convert base64 to numpy array - screenshot = np.frombuffer(base64.b64decode(screenshot_data["data"]), dtype=np.dtype(screenshot_data["dtype"])) - screenshot = screenshot.reshape(screenshot_data["shape"]) - response_json["obs"]["screenshot"] = screenshot + response_json = deserialize_response(response_json) if st.session_state.agent.obs_preprocessor: response_json["obs"] = st.session_state.agent.obs_preprocessor(response_json["obs"]) st.session_state.last_obs = response_json["obs"] @@ -257,17 +262,11 @@ def reload_task(): start = datetime.now() resp = requests.post(f"{SERVER_URL}/reload_task") if resp.status_code != 200 or resp.json().get("status") != "success": - print(resp.status_code) - print(resp.json()["status"]) - print(resp.json()["message"]) + logger.error(resp.status_code) + logger.error(resp.json()["status"]) + logger.error(resp.json()["message"]) response_json = resp.json() - if "obs" in response_json: - if "screenshot" in response_json["obs"]: - screenshot_data = response_json["obs"]["screenshot"] - # convert base64 to numpy array - screenshot = np.frombuffer(base64.b64decode(screenshot_data["data"]), dtype=np.dtype(screenshot_data["dtype"])) - screenshot = screenshot.reshape(screenshot_data["shape"]) - response_json["obs"]["screenshot"] = screenshot + response_json = deserialize_response(response_json) if st.session_state.agent.obs_preprocessor: response_json["obs"] = st.session_state.agent.obs_preprocessor(response_json["obs"]) st.session_state.last_obs = response_json["obs"] @@ -281,17 +280,12 @@ def step_environment(action): payload = {"action": action} resp = requests.post(f"{SERVER_URL}/step", json=payload) if resp.status_code != 200 or resp.json().get("status") != "success": - print(resp.status_code) - print(resp.json()["status"]) - print(resp.json()["message"]) + logger.error(resp.status_code) + logger.error(resp.json()["status"]) + logger.error(resp.json()["message"]) response_json = resp.json() - if "obs" in response_json: - if "screenshot" in response_json["obs"]: - screenshot_data = response_json["obs"]["screenshot"] - # convert base64 to numpy array - screenshot = np.frombuffer(base64.b64decode(screenshot_data["data"]), dtype=np.dtype(screenshot_data["dtype"])) - screenshot = screenshot.reshape(screenshot_data["shape"]) - response_json["obs"]["screenshot"] = screenshot + response_json = deserialize_response(response_json) + if st.session_state.agent.obs_preprocessor: response_json["obs"] = st.session_state.agent.obs_preprocessor(response_json["obs"]) st.session_state.last_obs = response_json["obs"] @@ -345,7 +339,6 @@ def set_agent_state_box(): with col1: with st.container(border=True, height=250): st.markdown("**Goal**") - # st.text_area("", st.session_state.agent.obs_history[-1]["goal"], height=175, disabled=True, label_visibility="collapsed") st.code(st.session_state.agent.obs_history[-1]["goal"], wrap_lines=True, language=None, height=175) with col2: with st.container(border=True, height=250): @@ -357,12 +350,10 @@ def set_agent_state_box(): with st.container(border=True, height=250): st.markdown("**Action**") st.session_state.action = st.text_area("Action", st.session_state.action, height=172, label_visibility="collapsed") - # st.code(st.session_state.action, wrap_lines=True, language="python", height=175) def set_prompt_modifier(): with st.expander("**Prompt Modifier**", expanded=False): - # st.write(st.session_state.agent.flags) st.markdown("**Observation Flags**") col1, col2, col3, col4, col5, col6 = st.columns([1, 1, 1, 1, 1, 1]) with col1: From b210b997e370d137de4c4c1a447aca2ebf429a54 Mon Sep 17 00:00:00 2001 From: Patrice Bechard Date: Thu, 12 Jun 2025 13:54:46 -0400 Subject: [PATCH 08/24] Add demo video for AgentLab Controller --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index 9b8d3233..b92d4753 100644 --- a/README.md +++ b/README.md @@ -228,6 +228,8 @@ image to select a step and observe the action taken by the agent. ### AgentLab Server and AgentLab Controller +https://github.com/user-attachments/assets/9a498c99-453a-4d7c-89fc-13e18db8dad6 + The AgentLab Server and Controller are two components that work together to control and debug an agent deployed in an environment. #### Prerequisites From 257c22e75d3c5cfb3789bd4ac21976c1dfe6d880 Mon Sep 17 00:00:00 2001 From: Patrice Bechard Date: Thu, 12 Jun 2025 14:05:15 -0400 Subject: [PATCH 09/24] add docstrings for server.py --- src/agentlab/analyze/server.py | 98 ++++++++++++++++++++++++++++++---- 1 file changed, 87 insertions(+), 11 deletions(-) diff --git a/src/agentlab/analyze/server.py b/src/agentlab/analyze/server.py index 53213cb0..ca61e039 100644 --- a/src/agentlab/analyze/server.py +++ b/src/agentlab/analyze/server.py @@ -2,6 +2,7 @@ import base64 import copy import importlib +from typing import Any import dotenv import numpy as np @@ -17,11 +18,17 @@ app = FastAPI() -def import_from_path(path): +def import_from_path(path: str) -> callable: """ Util function to import and instantiate a class, then return a specific method. For example, given `browsergym.core.action.highlevel.HighLevelActionSet.to_python_code`, this will instantiate `HighLevelActionSet` and return its `to_python_code` method. + + :param path: Path to the method + :type path: str + :raises ModuleNotFoundError: If the module cannot be imported + :return: The method + :rtype: callable """ parts = path.split(".") @@ -51,10 +58,15 @@ def import_from_path(path): return obj -def make_json_safe(obj): +def make_json_safe(obj: Any) -> Any: """ Util function to convert numpy arrays and other non-JSON-serializable objects to JSON-serializable objects. Specifically, we convert numpy arrays to base64 encoded strings so that payloads are of reasonable size. + + :param obj: Object to convert + :type obj: Any + :return: JSON-serializable object + :rtype: Any """ if isinstance(obj, np.ndarray): # convert to base64 @@ -418,6 +430,14 @@ def close(self) -> dict: # --- FastAPI endpoints --- @app.post("/set_info") def set_info(req: SetInfoRequest): + """ + Set the environment info. + + :param req: Request containing environment info + :type req: SetInfoRequest + :return: Dictionary with status + :rtype: dict + """ return env.set_info( benchmark_name=req.benchmark_name, task_name=req.task_name, @@ -428,47 +448,103 @@ def set_info(req: SetInfoRequest): @app.get("/get_info") -def get_info(): +def get_info() -> dict: + """ + Get the environment info. + + :return: Dictionary with info + :rtype: dict + """ return env.get_info() @app.post("/unset_info") -def unset_info(): +def unset_info() -> dict: + """ + Unset the environment info. + + :return: Dictionary with status + :rtype: dict + """ return env.unset_info() @app.get("/status") -def status(): +def status() -> dict: + """ + Get the status of the environment. + + :return: Dictionary with status + :rtype: dict + """ return env.status() @app.post("/prepare_benchmark") -def prepare_benchmark(): +def prepare_benchmark() -> dict: + """ + Prepare the benchmark. + + :return: Dictionary with status + :rtype: dict + """ return env.prepare_benchmark() @app.post("/reload_task") -def reload_task(): +def reload_task() -> dict: + """ + Reload the task. + + :return: Dictionary with status + :rtype: dict + """ return env.reload_task() @app.post("/reset") -def reset(): +def reset() -> dict: + """ + Reset the environment. + + :return: Dictionary with status + :rtype: dict + """ return env.reset() @app.post("/step") -def step(req: StepRequest): +def step(req: StepRequest) -> dict: + """ + Step the environment. + + :param req: Request containing action + :type req: StepRequest + :return: Dictionary with obs, reward, terminated, truncated and info + :rtype: dict + """ return env.step(action=req.action) @app.get("/get_obs") -def get_obs(): +def get_obs() -> dict: + """ + Get the last observation. + + :return: Dictionary with obs and info + :rtype: dict + """ return env.get_obs() @app.post("/close") -def close(): +def close() -> dict: + """ + Close the environment. + + :return: Dictionary with status + :rtype: dict + """ return env.close() From a51ea1b7f568c58afe770c5a0248b9c046dcf918 Mon Sep 17 00:00:00 2001 From: Patrice Bechard Date: Thu, 12 Jun 2025 14:12:25 -0400 Subject: [PATCH 10/24] change docstring style to google --- src/agentlab/analyze/server.py | 171 ++++++++++++++++++--------------- 1 file changed, 91 insertions(+), 80 deletions(-) diff --git a/src/agentlab/analyze/server.py b/src/agentlab/analyze/server.py index ca61e039..3c05b10f 100644 --- a/src/agentlab/analyze/server.py +++ b/src/agentlab/analyze/server.py @@ -21,14 +21,15 @@ def import_from_path(path: str) -> callable: """ Util function to import and instantiate a class, then return a specific method. - For example, given `browsergym.core.action.highlevel.HighLevelActionSet.to_python_code`, - this will instantiate `HighLevelActionSet` and return its `to_python_code` method. - - :param path: Path to the method - :type path: str - :raises ModuleNotFoundError: If the module cannot be imported - :return: The method - :rtype: callable + + Args: + path (str): Path to the method, e.g., 'browsergym.core.action.highlevel.HighLevelActionSet.to_python_code'. + + Raises: + ModuleNotFoundError: If the module cannot be imported. + + Returns: + callable: The method. """ parts = path.split(".") @@ -63,10 +64,11 @@ def make_json_safe(obj: Any) -> Any: Util function to convert numpy arrays and other non-JSON-serializable objects to JSON-serializable objects. Specifically, we convert numpy arrays to base64 encoded strings so that payloads are of reasonable size. - :param obj: Object to convert - :type obj: Any - :return: JSON-serializable object - :rtype: Any + Args: + obj (Any): Object to convert + + Returns: + Any: JSON-serializable object """ if isinstance(obj, np.ndarray): # convert to base64 @@ -122,21 +124,19 @@ def set_info( seed: int, action_mapping_fn: str, exp_dir: str, - ): - """Set the environment info. - - :param benchmark_name: Name of the benchmark - :type benchmark_name: str - :param task_name: Name of the task - :type task_name: str - :param seed: Seed of the task. - :type seed: int - :param action_mapping_fn: Action mapping function - :type action_mapping_fn: str - :param exp_dir: Directory for experiment - :type exp_dir: str - :return: Dictionary with status - :rtype: dict + ) -> dict: + """ + Set the environment info. + + Args: + benchmark_name (str): Name of the benchmark + task_name (str): Name of the task + seed (int): Seed of the task + action_mapping_fn (str): Action mapping function + exp_dir (str): Directory for experiment + + Returns: + dict: Dictionary with status """ if self.info_set: return make_json_safe( @@ -167,10 +167,11 @@ def set_info( ) def get_info(self) -> dict: - """Get the environment info + """ + Get the environment info. - :return: Dictionary with info - :rtype: dict + Returns: + dict: Dictionary with info """ if not self.info_set: return make_json_safe( @@ -192,10 +193,11 @@ def get_info(self) -> dict: ) def unset_info(self) -> dict: - """Unset the environment info + """ + Unset the environment info. - :return: Dictionary with status - :rtype: dict + Returns: + dict: Dictionary with status """ if not self.info_set: return make_json_safe( @@ -218,10 +220,11 @@ def unset_info(self) -> dict: ) def status(self) -> dict: - """Get the environment status + """ + Get the environment status. - :return: Dictionary with status - :rtype: dict + Returns: + dict: Dictionary with status """ return make_json_safe( { @@ -236,8 +239,8 @@ def prepare_benchmark(self) -> dict: """ Prepare the benchmark environment. - :return: Dictionary with status - :rtype: dict + Returns: + dict: Dictionary with status """ if not self.info_set: return make_json_safe( @@ -272,10 +275,11 @@ def prepare_benchmark(self) -> dict: ) def reload_task(self) -> dict: - """Reload the task + """ + Reload the task. - :return: Dictionary with status - :rtype: dict + Returns: + dict: Dictionary with status """ if not self.info_set: return make_json_safe( @@ -311,10 +315,11 @@ def reload_task(self) -> dict: ) def reset(self) -> dict: - """Reset the environment + """ + Reset the environment. - :return: Dictionary with obs and info - :rtype: dict + Returns: + dict: Dictionary with obs and info """ if not self.info_set: return make_json_safe( @@ -348,12 +353,14 @@ def reset(self) -> dict: ) def step(self, action: str) -> dict: - """Step the environment + """ + Step the environment. + + Args: + action (str): Action to take - :param action: Action to take - :type action: str - :return: Dictionary with obs, reward, terminated, truncated and info - :rtype: dict + Returns: + dict: Dictionary with obs, reward, terminated, truncated and info """ if self.env is None: return make_json_safe( @@ -380,10 +387,11 @@ def step(self, action: str) -> dict: ) def get_obs(self) -> dict: - """Get the last observation + """ + Get the last observation. - :return: Dictionary with obs and info - :rtype: dict + Returns: + dict: Dictionary with obs and info """ if self.env is None: return make_json_safe( @@ -402,10 +410,11 @@ def get_obs(self) -> dict: ) def close(self) -> dict: - """Close the environment + """ + Close the environment. - :return: Dictionary with status - :rtype: dict + Returns: + dict: Dictionary with status """ if self.env is None: return make_json_safe( @@ -429,14 +438,15 @@ def close(self) -> dict: # --- FastAPI endpoints --- @app.post("/set_info") -def set_info(req: SetInfoRequest): +def set_info(req: SetInfoRequest) -> dict: """ Set the environment info. - :param req: Request containing environment info - :type req: SetInfoRequest - :return: Dictionary with status - :rtype: dict + Args: + req (SetInfoRequest): Request containing environment info + + Returns: + dict: Dictionary with status """ return env.set_info( benchmark_name=req.benchmark_name, @@ -452,8 +462,8 @@ def get_info() -> dict: """ Get the environment info. - :return: Dictionary with info - :rtype: dict + Returns: + dict: Dictionary with info """ return env.get_info() @@ -463,8 +473,8 @@ def unset_info() -> dict: """ Unset the environment info. - :return: Dictionary with status - :rtype: dict + Returns: + dict: Dictionary with status """ return env.unset_info() @@ -474,8 +484,8 @@ def status() -> dict: """ Get the status of the environment. - :return: Dictionary with status - :rtype: dict + Returns: + dict: Dictionary with status """ return env.status() @@ -485,8 +495,8 @@ def prepare_benchmark() -> dict: """ Prepare the benchmark. - :return: Dictionary with status - :rtype: dict + Returns: + dict: Dictionary with status """ return env.prepare_benchmark() @@ -496,8 +506,8 @@ def reload_task() -> dict: """ Reload the task. - :return: Dictionary with status - :rtype: dict + Returns: + dict: Dictionary with status """ return env.reload_task() @@ -507,8 +517,8 @@ def reset() -> dict: """ Reset the environment. - :return: Dictionary with status - :rtype: dict + Returns: + dict: Dictionary with status """ return env.reset() @@ -518,10 +528,11 @@ def step(req: StepRequest) -> dict: """ Step the environment. - :param req: Request containing action - :type req: StepRequest - :return: Dictionary with obs, reward, terminated, truncated and info - :rtype: dict + Args: + req (StepRequest): Request containing action + + Returns: + dict: Dictionary with obs, reward, terminated, truncated and info """ return env.step(action=req.action) @@ -531,8 +542,8 @@ def get_obs() -> dict: """ Get the last observation. - :return: Dictionary with obs and info - :rtype: dict + Returns: + dict: Dictionary with obs and info """ return env.get_obs() @@ -542,8 +553,8 @@ def close() -> dict: """ Close the environment. - :return: Dictionary with status - :rtype: dict + Returns: + dict: Dictionary with status """ return env.close() From 0f678578431b2d1d7315d01660f19a701c01685a Mon Sep 17 00:00:00 2001 From: Patrice Bechard Date: Thu, 12 Jun 2025 14:21:11 -0400 Subject: [PATCH 11/24] format with black line length 100 --- src/agentlab/analyze/agent_controller.py | 131 +++++++++++++++++------ src/agentlab/analyze/server.py | 14 ++- 2 files changed, 112 insertions(+), 33 deletions(-) diff --git a/src/agentlab/analyze/agent_controller.py b/src/agentlab/analyze/agent_controller.py index 511ba0b3..c43cfe38 100644 --- a/src/agentlab/analyze/agent_controller.py +++ b/src/agentlab/analyze/agent_controller.py @@ -2,18 +2,19 @@ import copy import importlib import logging +from datetime import datetime from io import BytesIO -import requests + import numpy as np import PIL.Image +import requests import streamlit as st from agentlab.agents.generic_agent import __all__ as ALL_AGENTS from agentlab.experiments.exp_utils import RESULTS_DIR +from agentlab.llm.llm_utils import Discussion from bgym import DEFAULT_BENCHMARKS from dotenv import load_dotenv -from agentlab.llm.llm_utils import Discussion from transformers import AutoTokenizer -from datetime import datetime # used to display prompt. simple chat template from apache 2.0 model tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta") @@ -46,7 +47,9 @@ def deserialize_response(response_json): if "screenshot" in response_json["obs"]: screenshot_data = response_json["obs"]["screenshot"] # convert base64 to numpy array - screenshot = np.frombuffer(base64.b64decode(screenshot_data["data"]), dtype=np.dtype(screenshot_data["dtype"])) + screenshot = np.frombuffer( + base64.b64decode(screenshot_data["data"]), dtype=np.dtype(screenshot_data["dtype"]) + ) screenshot = screenshot.reshape(screenshot_data["shape"]) response_json["obs"]["screenshot"] = screenshot return response_json @@ -132,7 +135,9 @@ def select_agent(): def select_benchmark() -> str: """Dropdown to select a benchmark.""" all_benchmarks = list(DEFAULT_BENCHMARKS.keys()) - benchmark_str = st.selectbox("Select Benchmark", all_benchmarks, index=all_benchmarks.index(DEFAULT_BENCHMARK)) + benchmark_str = st.selectbox( + "Select Benchmark", all_benchmarks, index=all_benchmarks.index(DEFAULT_BENCHMARK) + ) return benchmark_str @@ -145,7 +150,9 @@ def select_task(benchmark): def select_subtask(benchmark, task_str) -> str: """Dropdown to select a subtask based on the task name.""" - all_subtasks = sorted([str(elem.task_seed) for elem in benchmark.env_args_list if elem.task_name == task_str]) + all_subtasks = sorted( + [str(elem.task_seed) for elem in benchmark.env_args_list if elem.task_name == task_str] + ) subtask_str = st.selectbox("Select Subtask", all_subtasks) return subtask_str @@ -153,7 +160,9 @@ def select_subtask(benchmark, task_str) -> str: def set_task_selector(): """Create task selector form. Allows the user to select the agent, benchmark, task, and subtask to run.""" with st.form("Task Selector"): - col1, col2, col3, col4, col5, col6 = st.columns([2, 2, 4, 2, 1, 1], vertical_alignment="bottom") + col1, col2, col3, col4, col5, col6 = st.columns( + [2, 2, 4, 2, 1, 1], vertical_alignment="bottom" + ) with col1: selected_agent_args = select_agent() with col2: @@ -339,17 +348,27 @@ def set_agent_state_box(): with col1: with st.container(border=True, height=250): st.markdown("**Goal**") - st.code(st.session_state.agent.obs_history[-1]["goal"], wrap_lines=True, language=None, height=175) + st.code( + st.session_state.agent.obs_history[-1]["goal"], + wrap_lines=True, + language=None, + height=175, + ) with col2: with st.container(border=True, height=250): st.markdown("**Think**") st.session_state.action_info.think = st.text_area( - "Think", st.session_state.action_info.think, height=172, label_visibility="collapsed" + "Think", + st.session_state.action_info.think, + height=172, + label_visibility="collapsed", ) with col3: with st.container(border=True, height=250): st.markdown("**Action**") - st.session_state.action = st.text_area("Action", st.session_state.action, height=172, label_visibility="collapsed") + st.session_state.action = st.text_area( + "Action", st.session_state.action, height=172, label_visibility="collapsed" + ) def set_prompt_modifier(): @@ -357,12 +376,16 @@ def set_prompt_modifier(): st.markdown("**Observation Flags**") col1, col2, col3, col4, col5, col6 = st.columns([1, 1, 1, 1, 1, 1]) with col1: - st.session_state.agent.flags.obs.use_html = st.checkbox("use_html", value=st.session_state.agent.flags.obs.use_html) + st.session_state.agent.flags.obs.use_html = st.checkbox( + "use_html", value=st.session_state.agent.flags.obs.use_html + ) st.session_state.agent.flags.obs.use_action_history = st.checkbox( "use_action_history", value=st.session_state.agent.flags.obs.use_action_history ) with col2: - st.session_state.agent.flags.obs.use_ax_tree = st.checkbox("use_ax_tree", value=st.session_state.agent.flags.obs.use_ax_tree) + st.session_state.agent.flags.obs.use_ax_tree = st.checkbox( + "use_ax_tree", value=st.session_state.agent.flags.obs.use_ax_tree + ) st.session_state.agent.flags.obs.use_think_history = st.checkbox( "use_think_history", value=st.session_state.agent.flags.obs.use_think_history ) @@ -370,7 +393,9 @@ def set_prompt_modifier(): st.session_state.agent.flags.obs.use_focused_element = st.checkbox( "use_focused_element", value=st.session_state.agent.flags.obs.use_focused_element ) - st.session_state.agent.flags.obs.use_diff = st.checkbox("use_diff", value=st.session_state.agent.flags.obs.use_diff) + st.session_state.agent.flags.obs.use_diff = st.checkbox( + "use_diff", value=st.session_state.agent.flags.obs.use_diff + ) with col4: st.session_state.agent.flags.obs.use_error_logs = st.checkbox( "use_error_logs", value=st.session_state.agent.flags.obs.use_error_logs @@ -379,26 +404,46 @@ def set_prompt_modifier(): "use_screenshot", value=st.session_state.agent.flags.obs.use_screenshot ) with col5: - st.session_state.agent.flags.obs.use_history = st.checkbox("use_history", value=st.session_state.agent.flags.obs.use_history) - st.session_state.agent.flags.obs.use_som = st.checkbox("use_som", value=st.session_state.agent.flags.obs.use_som) + st.session_state.agent.flags.obs.use_history = st.checkbox( + "use_history", value=st.session_state.agent.flags.obs.use_history + ) + st.session_state.agent.flags.obs.use_som = st.checkbox( + "use_som", value=st.session_state.agent.flags.obs.use_som + ) with col6: st.session_state.agent.flags.obs.use_past_error_logs = st.checkbox( "use_past_error_logs", value=st.session_state.agent.flags.obs.use_past_error_logs ) - st.session_state.agent.flags.obs.use_tabs = st.checkbox("use_tabs", value=st.session_state.agent.flags.obs.use_tabs) + st.session_state.agent.flags.obs.use_tabs = st.checkbox( + "use_tabs", value=st.session_state.agent.flags.obs.use_tabs + ) st.markdown("**Other Flags**") col1, col2, col3, col4, col5, col6 = st.columns([1, 1, 1, 1, 1, 1]) with col1: - st.session_state.agent.flags.use_plan = st.checkbox("use_plan", value=st.session_state.agent.flags.use_plan) - st.session_state.agent.flags.use_hints = st.checkbox("use_hints", value=st.session_state.agent.flags.use_hints) + st.session_state.agent.flags.use_plan = st.checkbox( + "use_plan", value=st.session_state.agent.flags.use_plan + ) + st.session_state.agent.flags.use_hints = st.checkbox( + "use_hints", value=st.session_state.agent.flags.use_hints + ) with col2: - st.session_state.agent.flags.use_criticise = st.checkbox("use_criticise", value=st.session_state.agent.flags.use_criticise) - st.session_state.agent.flags.be_cautious = st.checkbox("be_cautious", value=st.session_state.agent.flags.be_cautious) + st.session_state.agent.flags.use_criticise = st.checkbox( + "use_criticise", value=st.session_state.agent.flags.use_criticise + ) + st.session_state.agent.flags.be_cautious = st.checkbox( + "be_cautious", value=st.session_state.agent.flags.be_cautious + ) with col3: - st.session_state.agent.flags.use_thinking = st.checkbox("use_thinking", value=st.session_state.agent.flags.use_thinking) - st.session_state.agent.flags.enable_chat = st.checkbox("enable_chat", value=st.session_state.agent.flags.enable_chat) + st.session_state.agent.flags.use_thinking = st.checkbox( + "use_thinking", value=st.session_state.agent.flags.use_thinking + ) + st.session_state.agent.flags.enable_chat = st.checkbox( + "enable_chat", value=st.session_state.agent.flags.enable_chat + ) with col4: - st.session_state.agent.flags.use_memory = st.checkbox("use_memory", value=st.session_state.agent.flags.use_memory) + st.session_state.agent.flags.use_memory = st.checkbox( + "use_memory", value=st.session_state.agent.flags.use_memory + ) with col5: st.session_state.agent.flags.use_abstract_example = st.checkbox( "use_abstract_example", value=st.session_state.agent.flags.use_abstract_example @@ -407,7 +452,9 @@ def set_prompt_modifier(): st.session_state.agent.flags.use_concrete_example = st.checkbox( "use_concrete_example", value=st.session_state.agent.flags.use_concrete_example ) - extra_instructions = st.text_area("extra_instructions", value=st.session_state.agent.flags.extra_instructions) + extra_instructions = st.text_area( + "extra_instructions", value=st.session_state.agent.flags.extra_instructions + ) if extra_instructions == "": extra_instructions = None st.session_state.agent.flags.extra_instructions = extra_instructions @@ -429,7 +476,11 @@ def set_controller(): if st.button("⬅️ Previous Step", disabled=prev_disabled, use_container_width=True): if not prev_disabled: st.session_state.actions_history.pop() - st.session_state.action = None if len(st.session_state.actions_history) == 0 else st.session_state.actions_history[-1] + st.session_state.action = ( + None + if len(st.session_state.actions_history) == 0 + else st.session_state.actions_history[-1] + ) undo_last_agent_step() undo_last_agent_step() restore_environment() @@ -471,18 +522,31 @@ def set_axtree_tab(): def set_prompt_tab(): - if st.session_state.action_info is not None and isinstance(st.session_state.action_info.chat_messages, Discussion): + if st.session_state.action_info is not None and isinstance( + st.session_state.action_info.chat_messages, Discussion + ): chat_messages = st.session_state.action_info.chat_messages.messages new_chat_messages = [] for message in chat_messages: if isinstance(message["content"], list): # concatenate all text elements new_chat_messages.append( - {"role": message["role"], "content": "\n\n".join([elem["text"] for elem in message["content"] if elem["type"] == "text"])} + { + "role": message["role"], + "content": "\n\n".join( + [elem["text"] for elem in message["content"] if elem["type"] == "text"] + ), + } ) else: new_chat_messages.append(message) - st.code(tokenizer.apply_chat_template(new_chat_messages, add_special_tokens=True, tokenize=False), wrap_lines=True, language="markdown") + st.code( + tokenizer.apply_chat_template( + new_chat_messages, add_special_tokens=True, tokenize=False + ), + wrap_lines=True, + language="markdown", + ) def set_info_tabs(): @@ -500,8 +564,15 @@ def set_info_tabs(): def run_streamlit(): # config page - st.set_page_config(page_title="AgentLab Controller", page_icon="🎮", layout="wide", initial_sidebar_state="collapsed") - st.markdown('

🎮 AgentLab Controller 🎮

', unsafe_allow_html=True) + st.set_page_config( + page_title="AgentLab Controller", + page_icon="🎮", + layout="wide", + initial_sidebar_state="collapsed", + ) + st.markdown( + '

🎮 AgentLab Controller 🎮

', unsafe_allow_html=True + ) setup_sidebar() diff --git a/src/agentlab/analyze/server.py b/src/agentlab/analyze/server.py index 3c05b10f..df3bfbdd 100644 --- a/src/agentlab/analyze/server.py +++ b/src/agentlab/analyze/server.py @@ -72,7 +72,11 @@ def make_json_safe(obj: Any) -> Any: """ if isinstance(obj, np.ndarray): # convert to base64 - return {"data": base64.b64encode(obj.tobytes()).decode("utf-8"), "shape": obj.shape, "dtype": str(obj.dtype)} + return { + "data": base64.b64encode(obj.tobytes()).decode("utf-8"), + "shape": obj.shape, + "dtype": str(obj.dtype), + } elif isinstance(obj, dict): return {k: make_json_safe(v) for k, v in obj.items()} elif isinstance(obj, (list, tuple)): @@ -258,7 +262,9 @@ def prepare_benchmark(self) -> dict: # prepare backends benchmark = DEFAULT_BENCHMARKS[self.benchmark_name]() benchmark.env_args_list = [ - elem for elem in benchmark.env_args_list if elem.task_name == self.task_name and str(elem.task_seed) == str(self.seed) + elem + for elem in benchmark.env_args_list + if elem.task_name == self.task_name and str(elem.task_seed) == str(self.seed) ] benchmark.prepare_backends() @@ -300,7 +306,9 @@ def reload_task(self) -> dict: # NOTE: this is not guaranteed to result in the exact same state, but we find that it works most of the time, is much # faster than resetting the whole environment, and ensures the seed of the environment remains the same self.env.unwrapped.page.goto(self.start_url, wait_until="load") - self.env.unwrapped.page.evaluate("window.localStorage.clear(); window.sessionStorage.clear();") + self.env.unwrapped.page.evaluate( + "window.localStorage.clear(); window.sessionStorage.clear();" + ) obs = self.env.unwrapped._get_obs() self.last_obs = copy.deepcopy(obs) From 46f91d59b4eaf3248214da30007c9bd6ddcc43ea Mon Sep 17 00:00:00 2001 From: Patrice Bechard Date: Thu, 12 Jun 2025 15:15:39 -0400 Subject: [PATCH 12/24] remove forced background color which was breaking dark mode --- src/agentlab/analyze/agent_controller.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/agentlab/analyze/agent_controller.py b/src/agentlab/analyze/agent_controller.py index c43cfe38..5e6a497c 100644 --- a/src/agentlab/analyze/agent_controller.py +++ b/src/agentlab/analyze/agent_controller.py @@ -334,8 +334,6 @@ def set_agent_state_box(): font-style: normal; line-height: 1.6 !important; padding-top: 18px !important; - background-color: #F8F9FB !important; - } """, From 3e629ac5afa152b7b16803200289149929843cc4 Mon Sep 17 00:00:00 2001 From: Patrice Bechard Date: Wed, 18 Jun 2025 21:22:33 -0400 Subject: [PATCH 13/24] enable looking at all past steps in new tab --- src/agentlab/analyze/agent_controller.py | 308 +++++++++++++++-------- 1 file changed, 200 insertions(+), 108 deletions(-) diff --git a/src/agentlab/analyze/agent_controller.py b/src/agentlab/analyze/agent_controller.py index 5e6a497c..50c42ce8 100644 --- a/src/agentlab/analyze/agent_controller.py +++ b/src/agentlab/analyze/agent_controller.py @@ -29,6 +29,16 @@ SERVER_URL = "http://127.0.0.1:8000" +class Constants: + STATUS = "status" + STATUS_SUCCESS = "success" + STATUS_ERROR = "error" + + OBS = "obs" + SCREENSHOT = "screenshot" + AXTREE_TXT = "axtree_txt" + + class IgnoreMessageFilter(logging.Filter): def filter(self, record): return "but it does not exist!" not in record.getMessage() @@ -43,18 +53,103 @@ def get_import_path(obj): def deserialize_response(response_json): - if "obs" in response_json: - if "screenshot" in response_json["obs"]: - screenshot_data = response_json["obs"]["screenshot"] + if Constants.OBS in response_json: + if Constants.SCREENSHOT in response_json[Constants.OBS]: + screenshot_data = response_json[Constants.OBS][Constants.SCREENSHOT] # convert base64 to numpy array screenshot = np.frombuffer( base64.b64decode(screenshot_data["data"]), dtype=np.dtype(screenshot_data["dtype"]) ) screenshot = screenshot.reshape(screenshot_data["shape"]) - response_json["obs"]["screenshot"] = screenshot + response_json[Constants.OBS][Constants.SCREENSHOT] = screenshot return response_json +def reset_env_history(): + st.session_state.last_obs = None + st.session_state.obs_history = [] + st.session_state.screenshot_history = [] + st.session_state.axtree_history = [] + + +def reset_agent_history(): + st.session_state.action = None + st.session_state.action_info = None + st.session_state.action_history = [] + st.session_state.action_info_history = [] + st.session_state.thought_history = [] + st.session_state.prompt_history = [] + + +def reset_agent_state(): + st.session_state.agent.reset() + + +def step_env_history(obs): + st.session_state.last_obs = copy.deepcopy(obs) + st.session_state.obs_history.append(obs) + st.session_state.screenshot_history.append(obs[Constants.SCREENSHOT]) + st.session_state.axtree_history.append(obs[Constants.AXTREE_TXT]) + + +def step_agent_history(action, action_info): + st.session_state.action = copy.deepcopy(action) + st.session_state.action_info = copy.deepcopy(action_info) + st.session_state.action_history.append(action) + st.session_state.action_info_history.append(action_info) + st.session_state.thought_history.append(action_info.think) + st.session_state.prompt_history.append(get_prompt(action_info)) + + +def set_agent_state(): + st.session_state.agent.obs_history = st.session_state.obs_history + st.session_state.agent.actions = st.session_state.action_history + st.session_state.agent.thoughts = st.session_state.thought_history + + +def revert_env_history(): + st.session_state.obs_history.pop() + st.session_state.screenshot_history.pop() + st.session_state.axtree_history.pop() + + +def revert_agent_history(): + st.session_state.action_history.pop() + st.session_state.action_info_history.pop() + st.session_state.thought_history.pop() + st.session_state.prompt_history.pop() + + +def revert_agent_state(): + st.session_state.agent.obs_history.pop() + st.session_state.agent.actions.pop() + st.session_state.agent.thoughts.pop() + st.session_state.agent.memories.pop() + + +def get_prompt(info): + if info is not None and isinstance(info.chat_messages, Discussion): + chat_messages = info.chat_messages.messages + new_chat_messages = [] + for message in chat_messages: + if isinstance(message["content"], list): + # concatenate all text elements + new_chat_messages.append( + { + "role": message["role"], + "content": "\n\n".join( + [elem["text"] for elem in message["content"] if elem["type"] == "text"] + ), + } + ) + else: + new_chat_messages.append(message) + prompt = tokenizer.apply_chat_template( + new_chat_messages, add_special_tokens=True, tokenize=False + ) + return prompt + + def setup_sidebar(): with st.sidebar: st.markdown( @@ -105,16 +200,31 @@ def set_session_state(): if "subtask" not in st.session_state: st.session_state.subtask = None + # current state if "agent" not in st.session_state: st.session_state.agent = None - if "environment" not in st.session_state: - st.session_state.environment = None if "action" not in st.session_state: st.session_state.action = None if "action_info" not in st.session_state: st.session_state.action_info = None - if "actions_history" not in st.session_state: - st.session_state.actions_history = None + if "last_obs" not in st.session_state: + st.session_state.last_obs = None + + # track history + if "prompt_history" not in st.session_state: + st.session_state.prompt_history = None + if "screenshot_history" not in st.session_state: + st.session_state.screenshot_history = None + if "axtree_history" not in st.session_state: + st.session_state.axtree_history = None + if "thought_history" not in st.session_state: + st.session_state.thought_history = None + if "memory_history" not in st.session_state: + st.session_state.memory_history = None + if "action_history" not in st.session_state: + st.session_state.action_history = None + if "action_info_history" not in st.session_state: + st.session_state.action_info_history = None if "obs_history" not in st.session_state: st.session_state.obs_history = None @@ -185,11 +295,8 @@ def set_task_selector(): st.session_state.task = selected_task_str st.session_state.subtask = selected_subtask_str - # Set empty state tracker - st.session_state.current_action = None - st.session_state.last_obs = None - st.session_state.actions_history = [] - st.session_state.obs_history = [] + reset_env_history() + reset_agent_history() prepare_agent() set_environment_info() @@ -210,17 +317,11 @@ def clean_session(): def prepare_agent(): - logger.info("Preparing agent...") - start = datetime.now() st.session_state.agent_args.prepare() st.session_state.agent = st.session_state.agent_args.make_agent() - end = datetime.now() - logger.info(f"Done in {end - start}") def set_environment_info(): - logger.info("Setting environment info...") - start = datetime.now() action_mapping_fn = get_import_path(st.session_state.agent.action_set.to_python_code) payload = { "benchmark_name": st.session_state.benchmark, @@ -230,17 +331,15 @@ def set_environment_info(): "exp_dir": str(RESULTS_DIR), } resp = requests.post(f"{SERVER_URL}/set_info", json=payload) - if resp.status_code != 200 or resp.json().get("status") != "success": + if resp.status_code != 200 or resp.json().get(Constants.STATUS) != Constants.STATUS_SUCCESS: st.error(resp.json()) - end = datetime.now() - logger.info(f"Done in {end - start}") def prepare_benchmark(): logger.info("Preparing benchmark...") start = datetime.now() resp = requests.post(f"{SERVER_URL}/prepare_benchmark") - if resp.status_code != 200 or resp.json().get("status") != "success": + if resp.status_code != 200 or resp.json().get(Constants.STATUS) != Constants.STATUS_SUCCESS: st.error(resp.json()) end = datetime.now() logger.info(f"Done in {end - start}") @@ -252,33 +351,26 @@ def reset_environment(): resp = requests.post(f"{SERVER_URL}/reset") end = datetime.now() logger.info(f"Done request in {end - start}") - start = datetime.now() - if resp.status_code != 200 or resp.json().get("status") != "success": + if resp.status_code != 200 or resp.json().get(Constants.STATUS) != Constants.STATUS_SUCCESS: logger.error(resp.status_code) - logger.error(resp.json()["status"]) + logger.error(resp.json()[Constants.STATUS]) logger.error(resp.json()["message"]) response_json = resp.json() response_json = deserialize_response(response_json) + obs = response_json[Constants.OBS] if st.session_state.agent.obs_preprocessor: - response_json["obs"] = st.session_state.agent.obs_preprocessor(response_json["obs"]) - st.session_state.last_obs = response_json["obs"] - end = datetime.now() - logger.info(f"Done postproc in {end - start}") + obs = st.session_state.agent.obs_preprocessor(obs) + step_env_history(obs) def reload_task(): logger.info("Reloading task...") start = datetime.now() resp = requests.post(f"{SERVER_URL}/reload_task") - if resp.status_code != 200 or resp.json().get("status") != "success": + if resp.status_code != 200 or resp.json().get(Constants.STATUS) != Constants.STATUS_SUCCESS: logger.error(resp.status_code) - logger.error(resp.json()["status"]) + logger.error(resp.json()[Constants.STATUS]) logger.error(resp.json()["message"]) - response_json = resp.json() - response_json = deserialize_response(response_json) - if st.session_state.agent.obs_preprocessor: - response_json["obs"] = st.session_state.agent.obs_preprocessor(response_json["obs"]) - st.session_state.last_obs = response_json["obs"] end = datetime.now() logger.info(f"Done in {end - start}") @@ -288,16 +380,17 @@ def step_environment(action): start = datetime.now() payload = {"action": action} resp = requests.post(f"{SERVER_URL}/step", json=payload) - if resp.status_code != 200 or resp.json().get("status") != "success": + if resp.status_code != 200 or resp.json().get(Constants.STATUS) != Constants.STATUS_SUCCESS: logger.error(resp.status_code) - logger.error(resp.json()["status"]) + logger.error(resp.json()[Constants.STATUS]) logger.error(resp.json()["message"]) response_json = resp.json() response_json = deserialize_response(response_json) - if st.session_state.agent.obs_preprocessor: - response_json["obs"] = st.session_state.agent.obs_preprocessor(response_json["obs"]) - st.session_state.last_obs = response_json["obs"] + response_json[Constants.OBS] = st.session_state.agent.obs_preprocessor( + response_json[Constants.OBS] + ) + step_env_history(response_json[Constants.OBS]) st.session_state.action = None st.session_state.action_info = None end = datetime.now() @@ -306,16 +399,18 @@ def step_environment(action): def restore_environment(): reload_task() - for action in st.session_state.actions_history: + for action in st.session_state.agent.actions: step_environment(action) + st.session_state.action = st.session_state.action_history[-1] + st.session_state.action_info = st.session_state.action_info_history[-1] + set_agent_state() def get_action(): logger.info("Getting action...") start = datetime.now() action, info = st.session_state.agent.get_action(copy.deepcopy(st.session_state.last_obs)) - st.session_state.action = copy.deepcopy(action) - st.session_state.action_info = copy.deepcopy(info) + step_agent_history(action, info) end = datetime.now() logger.info(f"Done in {end - start}") @@ -347,7 +442,7 @@ def set_agent_state_box(): with st.container(border=True, height=250): st.markdown("**Goal**") st.code( - st.session_state.agent.obs_history[-1]["goal"], + st.session_state.last_obs["goal"], wrap_lines=True, language=None, height=175, @@ -458,105 +553,102 @@ def set_prompt_modifier(): st.session_state.agent.flags.extra_instructions = extra_instructions -def undo_last_agent_step(): - st.session_state.agent.obs_history.pop() # remove last observation - st.session_state.agent.actions.pop() # remove last action - st.session_state.agent.thoughts.pop() # remove last thought - st.session_state.agent.memories.pop() # remove last memory - - def set_controller(): set_agent_state_box() set_prompt_modifier() col_prev, col_redo, col_next = st.columns([1, 1, 1]) with col_prev: - prev_disabled = len(st.session_state.actions_history) == 0 + prev_disabled = len(st.session_state.action_history) <= 1 if st.button("⬅️ Previous Step", disabled=prev_disabled, use_container_width=True): if not prev_disabled: - st.session_state.actions_history.pop() st.session_state.action = ( None - if len(st.session_state.actions_history) == 0 - else st.session_state.actions_history[-1] + if len(st.session_state.action_history) == 0 + else st.session_state.action_history[-1] ) - undo_last_agent_step() - undo_last_agent_step() + reset_agent_state() + revert_agent_history() + revert_env_history() restore_environment() st.rerun() with col_redo: if st.button("🔄 Regenerate Action", use_container_width=True): - undo_last_agent_step() + revert_agent_history() + revert_agent_state() get_action() st.rerun() with col_next: if st.button("➡️ Next Step", use_container_width=True): - st.session_state.actions_history.append(st.session_state.action) step_environment(st.session_state.action) - st.session_state.action = None st.rerun() +def display_image(img_arr): + if isinstance(img_arr, list): + img_arr = np.array(img_arr) + if isinstance(img_arr, np.ndarray): + im = PIL.Image.fromarray(img_arr) + buffered = BytesIO() + im.save(buffered, format="PNG") + img_b64 = base64.b64encode(buffered.getvalue()).decode() + st.markdown( + f'
', + unsafe_allow_html=True, + ) + + def set_screenshot_tab(): - if isinstance(st.session_state.last_obs, dict): - if st.session_state.last_obs.get("screenshot", None) is not None: - img_arr = st.session_state.last_obs["screenshot"] - if isinstance(img_arr, list): - img_arr = np.array(img_arr) - if isinstance(img_arr, np.ndarray): - im = PIL.Image.fromarray(img_arr) - buffered = BytesIO() - im.save(buffered, format="PNG") - img_b64 = base64.b64encode(buffered.getvalue()).decode() - st.markdown( - f'
', - unsafe_allow_html=True, - ) + display_image(st.session_state.screenshot_history[-1]) def set_axtree_tab(): - if isinstance(st.session_state.last_obs, dict): - if st.session_state.last_obs.get("axtree_txt", None) is not None: - st.code(st.session_state.last_obs["axtree_txt"], language=None) + st.code(st.session_state.axtree_history[-1], language=None, wrap_lines=True) def set_prompt_tab(): - if st.session_state.action_info is not None and isinstance( - st.session_state.action_info.chat_messages, Discussion - ): - chat_messages = st.session_state.action_info.chat_messages.messages - new_chat_messages = [] - for message in chat_messages: - if isinstance(message["content"], list): - # concatenate all text elements - new_chat_messages.append( - { - "role": message["role"], - "content": "\n\n".join( - [elem["text"] for elem in message["content"] if elem["type"] == "text"] - ), - } - ) - else: - new_chat_messages.append(message) - st.code( - tokenizer.apply_chat_template( - new_chat_messages, add_special_tokens=True, tokenize=False - ), - wrap_lines=True, - language="markdown", - ) + st.code(st.session_state.prompt_history[-1], language=None, wrap_lines=True) + + +def set_previous_steps_tab(): + for i in range(len(st.session_state.action_history) - 1): + with st.expander(f"### Step {i + 1}", expanded=False): + screenshot_tab, axtree_tab, prompt_tab = st.tabs(["Screenshot", "AxTree", "Prompt"]) + with screenshot_tab: + display_image(st.session_state.screenshot_history[i]) + with axtree_tab: + st.code(st.session_state.axtree_history[i], language=None, wrap_lines=True) + with prompt_tab: + st.code(st.session_state.prompt_history[i], language=None, wrap_lines=True) + st.markdown("**Thought**") + st.code(st.session_state.thought_history[i], language=None, wrap_lines=True) + st.markdown("**Action**") + st.code(st.session_state.action_history[i], language=None, wrap_lines=True) def set_info_tabs(): + print(len(st.session_state.action_history)) + print(len(st.session_state.screenshot_history)) + print(len(st.session_state.axtree_history)) + print(len(st.session_state.prompt_history)) + print(len(st.session_state.thought_history)) + print("---") # Display only if everything is now ready - tab1, tab2, tab3 = st.tabs(["Screenshot", "AxTree", "Prompt"]) + if len(st.session_state.action_history) > 1: + screenshot_tab, axtree_tab, prompt_tab, previous_steps_tab = st.tabs( + ["Screenshot", "AxTree", "Prompt", "Previous Steps"] + ) + else: + screenshot_tab, axtree_tab, prompt_tab = st.tabs(["Screenshot", "AxTree", "Prompt"]) - with tab1: + with screenshot_tab: set_screenshot_tab() - with tab2: + with axtree_tab: set_axtree_tab() - with tab3: + with prompt_tab: set_prompt_tab() + if len(st.session_state.action_history) > 1: + with previous_steps_tab: + set_previous_steps_tab() def run_streamlit(): From 13ee9575bd8a49aaa6235a262b1f5979f313dd0e Mon Sep 17 00:00:00 2001 From: Patrice Bechard Date: Wed, 18 Jun 2025 21:45:51 -0400 Subject: [PATCH 14/24] add button to go back to arbitrary past step --- src/agentlab/analyze/agent_controller.py | 25 ++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/src/agentlab/analyze/agent_controller.py b/src/agentlab/analyze/agent_controller.py index 50c42ce8..cdb01dc6 100644 --- a/src/agentlab/analyze/agent_controller.py +++ b/src/agentlab/analyze/agent_controller.py @@ -127,6 +127,19 @@ def revert_agent_state(): st.session_state.agent.memories.pop() +def restore_env_history(step: int): + st.session_state.obs_history = st.session_state.obs_history[:step] + st.session_state.screenshot_history = st.session_state.screenshot_history[:step] + st.session_state.axtree_history = st.session_state.axtree_history[:step] + + +def restore_agent_history(step: int): + st.session_state.action_history = st.session_state.action_history[:step] + st.session_state.action_info_history = st.session_state.action_info_history[:step] + st.session_state.thought_history = st.session_state.thought_history[:step] + st.session_state.prompt_history = st.session_state.prompt_history[:step] + + def get_prompt(info): if info is not None and isinstance(info.chat_messages, Discussion): chat_messages = info.chat_messages.messages @@ -612,6 +625,12 @@ def set_prompt_tab(): def set_previous_steps_tab(): for i in range(len(st.session_state.action_history) - 1): with st.expander(f"### Step {i + 1}", expanded=False): + if st.button(f"Go back to step {i + 1}"): + reset_agent_state() + restore_agent_history(step=i + 1) + restore_env_history(step=i + 1) + restore_environment() + st.rerun() screenshot_tab, axtree_tab, prompt_tab = st.tabs(["Screenshot", "AxTree", "Prompt"]) with screenshot_tab: display_image(st.session_state.screenshot_history[i]) @@ -626,12 +645,6 @@ def set_previous_steps_tab(): def set_info_tabs(): - print(len(st.session_state.action_history)) - print(len(st.session_state.screenshot_history)) - print(len(st.session_state.axtree_history)) - print(len(st.session_state.prompt_history)) - print(len(st.session_state.thought_history)) - print("---") # Display only if everything is now ready if len(st.session_state.action_history) > 1: screenshot_tab, axtree_tab, prompt_tab, previous_steps_tab = st.tabs( From f9fdd4ecb7f86b4d0a4c16f565304b7fbb483fbf Mon Sep 17 00:00:00 2001 From: Patrice Bechard Date: Wed, 18 Jun 2025 22:34:20 -0400 Subject: [PATCH 15/24] implement save feature to save traces and hints --- src/agentlab/analyze/agent_controller.py | 70 +++++++++++++++++++++++- 1 file changed, 67 insertions(+), 3 deletions(-) diff --git a/src/agentlab/analyze/agent_controller.py b/src/agentlab/analyze/agent_controller.py index cdb01dc6..da9444b1 100644 --- a/src/agentlab/analyze/agent_controller.py +++ b/src/agentlab/analyze/agent_controller.py @@ -1,7 +1,9 @@ import base64 import copy import importlib +import json import logging +import os from datetime import datetime from io import BytesIO @@ -48,6 +50,14 @@ def filter(self, record): streamlit_logger.setLevel(logging.ERROR) +def is_json_serializable(value): + try: + json.dumps(value) + return True + except (TypeError, OverflowError): + return False + + def get_import_path(obj): return f"{obj.__module__}.{obj.__qualname__}" @@ -596,7 +606,7 @@ def set_controller(): st.rerun() -def display_image(img_arr): +def get_base64_serialized_image(img_arr): if isinstance(img_arr, list): img_arr = np.array(img_arr) if isinstance(img_arr, np.ndarray): @@ -604,6 +614,13 @@ def display_image(img_arr): buffered = BytesIO() im.save(buffered, format="PNG") img_b64 = base64.b64encode(buffered.getvalue()).decode() + return img_b64 + return None + + +def display_image(img_arr): + img_b64 = get_base64_serialized_image(img_arr) + if img_b64: st.markdown( f'
', unsafe_allow_html=True, @@ -644,11 +661,56 @@ def set_previous_steps_tab(): st.code(st.session_state.action_history[i], language=None, wrap_lines=True) +def set_save_tab(): + # dump full session_state to json + save_dir = st.text_input("Save Directory", value="~/Downloads") + save_dir = os.path.expanduser(save_dir) + if st.button("Save Session State for Current Run"): + now_str = datetime.now().strftime("%Y_%m_%d__%H_%M_%S") + filename = f"agentlab_controller_state_{now_str}.json" + + # prepare payload for saving + payload = {} + payload["timestamp"] = now_str + payload["benchmark"] = st.session_state.benchmark + payload["task"] = st.session_state.task + payload["subtask"] = st.session_state.subtask + payload["agent_args"] = { + k: v for k, v in vars(st.session_state.agent_args).items() if is_json_serializable(v) + } + payload["agent_flags"] = { + k: v for k, v in vars(st.session_state.agent.flags).items() if is_json_serializable(v) + } + payload["agent_flags"]["obs"] = { + k: v + for k, v in vars(st.session_state.agent.flags.obs).items() + if is_json_serializable(v) + } + payload["agent_flags"]["action"] = { + k: v + for k, v in vars(st.session_state.agent.flags.action).items() + if is_json_serializable(v) + } + payload["goal"] = st.session_state.last_obs["goal"] + payload["steps"] = [] + for i in range(len(st.session_state.action_history)): + step = {} + step["action"] = st.session_state.action_history[i] + step["thought"] = st.session_state.thought_history[i] + step["prompt"] = st.session_state.prompt_history[i] + step["screenshot"] = get_base64_serialized_image(st.session_state.screenshot_history[i]) + step["axtree"] = st.session_state.axtree_history[i] + payload["steps"].append(step) + + with open(os.path.join(save_dir, filename), "w") as f: + json.dump(payload, f) + + def set_info_tabs(): # Display only if everything is now ready if len(st.session_state.action_history) > 1: - screenshot_tab, axtree_tab, prompt_tab, previous_steps_tab = st.tabs( - ["Screenshot", "AxTree", "Prompt", "Previous Steps"] + screenshot_tab, axtree_tab, prompt_tab, previous_steps_tab, save_tab = st.tabs( + ["Screenshot", "AxTree", "Prompt", "Previous Steps", "Save"] ) else: screenshot_tab, axtree_tab, prompt_tab = st.tabs(["Screenshot", "AxTree", "Prompt"]) @@ -662,6 +724,8 @@ def set_info_tabs(): if len(st.session_state.action_history) > 1: with previous_steps_tab: set_previous_steps_tab() + with save_tab: + set_save_tab() def run_streamlit(): From b7a54ca8bada5a6beb630cf26d95dae2cc58fa44 Mon Sep 17 00:00:00 2001 From: Patrice Bechard Date: Thu, 19 Jun 2025 09:48:06 -0400 Subject: [PATCH 16/24] add advanced options to go back to step k, reprompt k times, and act k times --- src/agentlab/analyze/agent_controller.py | 87 ++++++++++++++++++++++++ 1 file changed, 87 insertions(+) diff --git a/src/agentlab/analyze/agent_controller.py b/src/agentlab/analyze/agent_controller.py index da9444b1..c0f4851a 100644 --- a/src/agentlab/analyze/agent_controller.py +++ b/src/agentlab/analyze/agent_controller.py @@ -4,6 +4,7 @@ import json import logging import os +from collections import Counter from datetime import datetime from io import BytesIO @@ -576,6 +577,91 @@ def set_prompt_modifier(): st.session_state.agent.flags.extra_instructions = extra_instructions +def set_advanced_controller(): + with st.expander("**Advanced**", expanded=False): + col_go_back_to, col_reprompt_k, col_act_k = st.columns([1, 1, 1]) + with col_go_back_to: + with st.container(border=True): + st.markdown("**Go Back to Step K**") + col1, col2 = st.columns([1, 1], vertical_alignment="bottom") + is_go_back_to_step_k_disabled = len(st.session_state.action_history) <= 1 + with col1: + step = st.number_input( + "Step", + value=1, + min_value=1, + max_value=len(st.session_state.action_history), + disabled=is_go_back_to_step_k_disabled, + ) + with col2: + if st.button( + "Go Back", + help="Go back to step K", + use_container_width=True, + disabled=is_go_back_to_step_k_disabled, + ): + reset_agent_state() + restore_agent_history(step=step) + restore_env_history(step=step) + restore_environment() + st.rerun() + with col_reprompt_k: + with st.container(border=True): + st.markdown("**Reprompt K Times**") + col1, col2 = st.columns([1, 1], vertical_alignment="bottom") + with col1: + k = st.number_input( + "Number of Generations", + value=5, + min_value=1, + max_value=25, + ) + with col2: + has_clicked_reprompt = st.button( + "Reprompt", + help="Reprompt the agent K times to get a distribution of actions to take", + use_container_width=True, + ) + if has_clicked_reprompt: + reprompt_actions = [] + with st.spinner(f"Reprompting {k} times"): + for i in range(k): + revert_agent_history() + revert_agent_state() + get_action() + reprompt_actions.append(st.session_state.action) + # show all unique actions found in reprompt actions along with their probability + unique_actions_counter = Counter(reprompt_actions) + unique_actions = sorted( + unique_actions_counter.items(), key=lambda x: x[1], reverse=True + ) + st.markdown("**Unique Actions**") + for action, count in unique_actions: + selected_action = st.button(f"`{action}` ({count / k * 100:.2f}%)") + if selected_action: + step_environment(action) + st.rerun() + + with col_act_k: + with st.container(border=True): + st.markdown("**Act K Times**") + col1, col2 = st.columns([1, 1], vertical_alignment="bottom") + with col1: + k = st.number_input("Number of Steps", value=5, min_value=1, max_value=10) + with col2: + has_clicked_act = st.button( + "Act", + help="Let the agent autonomously perform actions for K steps", + use_container_width=True, + ) + if has_clicked_act: + with st.spinner(f"Acting {k} times"): + for _ in range(k): + get_action() + step_environment(st.session_state.action) + st.rerun() + + def set_controller(): set_agent_state_box() set_prompt_modifier() @@ -604,6 +690,7 @@ def set_controller(): if st.button("➡️ Next Step", use_container_width=True): step_environment(st.session_state.action) st.rerun() + set_advanced_controller() def get_base64_serialized_image(img_arr): From 6a516e9980663bae4fd9a0d82d124d77ad8222fe Mon Sep 17 00:00:00 2001 From: Patrice Bechard Date: Thu, 19 Jun 2025 16:15:25 -0400 Subject: [PATCH 17/24] minor refactoring --- src/agentlab/analyze/agent_controller.py | 226 +++++++++++++---------- 1 file changed, 125 insertions(+), 101 deletions(-) diff --git a/src/agentlab/analyze/agent_controller.py b/src/agentlab/analyze/agent_controller.py index c0f4851a..35740c03 100644 --- a/src/agentlab/analyze/agent_controller.py +++ b/src/agentlab/analyze/agent_controller.py @@ -36,6 +36,7 @@ class Constants: STATUS = "status" STATUS_SUCCESS = "success" STATUS_ERROR = "error" + MESSAGE = "message" OBS = "obs" SCREENSHOT = "screenshot" @@ -378,7 +379,7 @@ def reset_environment(): if resp.status_code != 200 or resp.json().get(Constants.STATUS) != Constants.STATUS_SUCCESS: logger.error(resp.status_code) logger.error(resp.json()[Constants.STATUS]) - logger.error(resp.json()["message"]) + logger.error(resp.json()[Constants.MESSAGE]) response_json = resp.json() response_json = deserialize_response(response_json) obs = response_json[Constants.OBS] @@ -394,7 +395,7 @@ def reload_task(): if resp.status_code != 200 or resp.json().get(Constants.STATUS) != Constants.STATUS_SUCCESS: logger.error(resp.status_code) logger.error(resp.json()[Constants.STATUS]) - logger.error(resp.json()["message"]) + logger.error(resp.json()[Constants.MESSAGE]) end = datetime.now() logger.info(f"Done in {end - start}") @@ -407,7 +408,7 @@ def step_environment(action): if resp.status_code != 200 or resp.json().get(Constants.STATUS) != Constants.STATUS_SUCCESS: logger.error(resp.status_code) logger.error(resp.json()[Constants.STATUS]) - logger.error(resp.json()["message"]) + logger.error(resp.json()[Constants.MESSAGE]) response_json = resp.json() response_json = deserialize_response(response_json) if st.session_state.agent.obs_preprocessor: @@ -577,89 +578,130 @@ def set_prompt_modifier(): st.session_state.agent.flags.extra_instructions = extra_instructions +def set_go_back_to_step_k_section(): + with st.container(border=True): + st.markdown("**Go Back to Step K**") + col1, col2 = st.columns([1, 1], vertical_alignment="bottom") + is_go_back_to_step_k_disabled = len(st.session_state.action_history) <= 1 + with col1: + step = st.number_input( + "Step", + value=1, + min_value=1, + max_value=len(st.session_state.action_history), + disabled=is_go_back_to_step_k_disabled, + ) + with col2: + if st.button( + "Go Back", + help="Go back to step K", + use_container_width=True, + disabled=is_go_back_to_step_k_disabled, + ): + reset_agent_state() + restore_agent_history(step=step) + restore_env_history(step=step) + restore_environment() + st.rerun() + + +def set_reprompt_k_times_section(): + with st.container(border=True): + st.markdown("**Reprompt K Times**") + col1, col2 = st.columns([1, 1], vertical_alignment="bottom") + with col1: + k = st.number_input( + "Number of Generations", + value=5, + min_value=1, + max_value=25, + ) + with col2: + has_clicked_reprompt = st.button( + "Reprompt", + help="Reprompt the agent K times to get a distribution of actions to take", + use_container_width=True, + ) + if has_clicked_reprompt: + reprompt_actions = [] + with st.spinner(f"Reprompting {k} times"): + for i in range(k): + revert_agent_history() + revert_agent_state() + get_action() + reprompt_actions.append(st.session_state.action) + # show all unique actions found in reprompt actions along with their probability + unique_actions_counter = Counter(reprompt_actions) + unique_actions = sorted( + unique_actions_counter.items(), key=lambda x: x[1], reverse=True + ) + st.markdown("**Unique Actions**") + for action, count in unique_actions: + selected_action = st.button(f"`{action}` ({count / k * 100:.2f}%)") + if selected_action: + step_environment(action) + st.rerun() + + +def set_act_k_times_section(): + with st.container(border=True): + st.markdown("**Act K Times**") + col1, col2 = st.columns([1, 1], vertical_alignment="bottom") + with col1: + k = st.number_input("Number of Steps", value=5, min_value=1, max_value=10) + with col2: + has_clicked_act = st.button( + "Act", + help="Let the agent autonomously perform actions for K steps", + use_container_width=True, + ) + if has_clicked_act: + with st.spinner(f"Acting {k} times"): + for _ in range(k): + get_action() + step_environment(st.session_state.action) + st.rerun() + + def set_advanced_controller(): with st.expander("**Advanced**", expanded=False): col_go_back_to, col_reprompt_k, col_act_k = st.columns([1, 1, 1]) with col_go_back_to: - with st.container(border=True): - st.markdown("**Go Back to Step K**") - col1, col2 = st.columns([1, 1], vertical_alignment="bottom") - is_go_back_to_step_k_disabled = len(st.session_state.action_history) <= 1 - with col1: - step = st.number_input( - "Step", - value=1, - min_value=1, - max_value=len(st.session_state.action_history), - disabled=is_go_back_to_step_k_disabled, - ) - with col2: - if st.button( - "Go Back", - help="Go back to step K", - use_container_width=True, - disabled=is_go_back_to_step_k_disabled, - ): - reset_agent_state() - restore_agent_history(step=step) - restore_env_history(step=step) - restore_environment() - st.rerun() + set_go_back_to_step_k_section() with col_reprompt_k: - with st.container(border=True): - st.markdown("**Reprompt K Times**") - col1, col2 = st.columns([1, 1], vertical_alignment="bottom") - with col1: - k = st.number_input( - "Number of Generations", - value=5, - min_value=1, - max_value=25, - ) - with col2: - has_clicked_reprompt = st.button( - "Reprompt", - help="Reprompt the agent K times to get a distribution of actions to take", - use_container_width=True, - ) - if has_clicked_reprompt: - reprompt_actions = [] - with st.spinner(f"Reprompting {k} times"): - for i in range(k): - revert_agent_history() - revert_agent_state() - get_action() - reprompt_actions.append(st.session_state.action) - # show all unique actions found in reprompt actions along with their probability - unique_actions_counter = Counter(reprompt_actions) - unique_actions = sorted( - unique_actions_counter.items(), key=lambda x: x[1], reverse=True - ) - st.markdown("**Unique Actions**") - for action, count in unique_actions: - selected_action = st.button(f"`{action}` ({count / k * 100:.2f}%)") - if selected_action: - step_environment(action) - st.rerun() - + set_reprompt_k_times_section() with col_act_k: - with st.container(border=True): - st.markdown("**Act K Times**") - col1, col2 = st.columns([1, 1], vertical_alignment="bottom") - with col1: - k = st.number_input("Number of Steps", value=5, min_value=1, max_value=10) - with col2: - has_clicked_act = st.button( - "Act", - help="Let the agent autonomously perform actions for K steps", - use_container_width=True, - ) - if has_clicked_act: - with st.spinner(f"Acting {k} times"): - for _ in range(k): - get_action() - step_environment(st.session_state.action) - st.rerun() + set_act_k_times_section() + + +def set_previous_step_section(): + prev_disabled = len(st.session_state.action_history) <= 1 + if st.button("⬅️ Previous Step", disabled=prev_disabled, use_container_width=True): + if not prev_disabled: + st.session_state.action = ( + None + if len(st.session_state.action_history) == 0 + else st.session_state.action_history[-1] + ) + reset_agent_state() + revert_agent_history() + revert_env_history() + restore_environment() + st.rerun() + + +def set_regenerate_action_section(): + if st.button("🔄 Regenerate Action", use_container_width=True): + revert_agent_history() + revert_agent_state() + get_action() + st.rerun() + + +def set_next_step_section(): + if st.button("➡️ Next Step", use_container_width=True): + step_environment(st.session_state.action) + st.rerun() def set_controller(): @@ -667,29 +709,11 @@ def set_controller(): set_prompt_modifier() col_prev, col_redo, col_next = st.columns([1, 1, 1]) with col_prev: - prev_disabled = len(st.session_state.action_history) <= 1 - if st.button("⬅️ Previous Step", disabled=prev_disabled, use_container_width=True): - if not prev_disabled: - st.session_state.action = ( - None - if len(st.session_state.action_history) == 0 - else st.session_state.action_history[-1] - ) - reset_agent_state() - revert_agent_history() - revert_env_history() - restore_environment() - st.rerun() + set_previous_step_section() with col_redo: - if st.button("🔄 Regenerate Action", use_container_width=True): - revert_agent_history() - revert_agent_state() - get_action() - st.rerun() + set_regenerate_action_section() with col_next: - if st.button("➡️ Next Step", use_container_width=True): - step_environment(st.session_state.action) - st.rerun() + set_next_step_section() set_advanced_controller() From 7aa7a09119053327e7bad8b80549f6941341d672 Mon Sep 17 00:00:00 2001 From: Patrice Bechard Date: Sun, 22 Jun 2025 12:31:40 -0400 Subject: [PATCH 18/24] bug fixes --- src/agentlab/analyze/agent_controller.py | 400 ++++++++++++++--------- 1 file changed, 251 insertions(+), 149 deletions(-) diff --git a/src/agentlab/analyze/agent_controller.py b/src/agentlab/analyze/agent_controller.py index 35740c03..a2e90319 100644 --- a/src/agentlab/analyze/agent_controller.py +++ b/src/agentlab/analyze/agent_controller.py @@ -31,6 +31,39 @@ SERVER_URL = "http://127.0.0.1:8000" +# region Sidebar Text +SIDEBAR_TEXT = """ +# AgentLab Controller + +AgentLab Controller is a tool used to help control and debug an agent deployed in an environment. + +AgentLab Controller works by connecting a Streamlit UI that handles the agent to a FastAPI backend server that handles the environment. + +--- + +## Instructions + +1. ⚙️ Setup the task + - Select an agent, benchmark, task, and subtask you want to work on. + - Select "🔄" to reset the environment. This includes resetting the environment server. + - Select "▶️" to start the environment. This will start the environment by opening a browser in the background. This step might take some time + +2. 🎮 Control the environment + - Look at the goal set for the task, the thought of the model, and the action taken. + - If the action looks right, select the "▶️ Next Step" button to step the environment. + + The action will then be executed and the environment will be updated. + - If the action is wrong and you want to re-prompt, select the "🔄 Regenerate Action". + + You can also expand the "Prompt Modifier" menu to change the prompt used to generate the thoughts / actions. + - If you want to backtrack and undo the previous actions, select the "⬅️ Previous Step" button. + + Note: This is a slow process as we need to reset the environment server and perform the previous actions one by one. + +3. 🔎 Investigate the environment + - Look at the screenshot of the current environment state + - Verify that the action selected by the model matches the AxTree + - Ensure that the prompt is properly build. If there are issues with the prompt yielding the wrong action, modify them using the "Prompt Modifier" above. +""" +# endregion + class Constants: STATUS = "status" @@ -52,6 +85,19 @@ def filter(self, record): streamlit_logger.setLevel(logging.ERROR) +def make_hashable(obj): + if isinstance(obj, np.ndarray): + # Use shape, dtype, and bytes for uniqueness + return (obj.shape, obj.dtype.str, obj.tobytes()) + elif isinstance(obj, (tuple, list)): + return tuple(make_hashable(x) for x in obj) + elif isinstance(obj, dict): + # Sort keys to ensure consistent order + return tuple(sorted((k, make_hashable(v)) for k, v in obj.items())) + else: + return obj # Assume it's already hashable + + def is_json_serializable(value): try: json.dumps(value) @@ -78,6 +124,7 @@ def deserialize_response(response_json): def reset_env_history(): + logger.info("Resetting env history") st.session_state.last_obs = None st.session_state.obs_history = [] st.session_state.screenshot_history = [] @@ -85,19 +132,23 @@ def reset_env_history(): def reset_agent_history(): + logger.info("Resetting agent history") st.session_state.action = None st.session_state.action_info = None st.session_state.action_history = [] st.session_state.action_info_history = [] st.session_state.thought_history = [] st.session_state.prompt_history = [] + st.session_state.memory_history = [] def reset_agent_state(): + logger.info("Resetting agent state") st.session_state.agent.reset() def step_env_history(obs): + logger.info("Stepping env history") st.session_state.last_obs = copy.deepcopy(obs) st.session_state.obs_history.append(obs) st.session_state.screenshot_history.append(obs[Constants.SCREENSHOT]) @@ -105,6 +156,7 @@ def step_env_history(obs): def step_agent_history(action, action_info): + logger.info("Stepping agent history") st.session_state.action = copy.deepcopy(action) st.session_state.action_info = copy.deepcopy(action_info) st.session_state.action_history.append(action) @@ -112,27 +164,36 @@ def step_agent_history(action, action_info): st.session_state.thought_history.append(action_info.think) st.session_state.prompt_history.append(get_prompt(action_info)) + # HACK: memory history can only be obtained via the agent + st.session_state.memory_history.append(st.session_state.agent.memories[-1]) + def set_agent_state(): - st.session_state.agent.obs_history = st.session_state.obs_history - st.session_state.agent.actions = st.session_state.action_history - st.session_state.agent.thoughts = st.session_state.thought_history + logger.info("Setting agent state") + st.session_state.agent.obs_history = copy.deepcopy(st.session_state.obs_history) + st.session_state.agent.actions = copy.deepcopy(st.session_state.action_history) + st.session_state.agent.thoughts = copy.deepcopy(st.session_state.thought_history) + st.session_state.agent.memories = copy.deepcopy(st.session_state.memory_history) def revert_env_history(): + logger.info("Reverting env history") st.session_state.obs_history.pop() st.session_state.screenshot_history.pop() st.session_state.axtree_history.pop() def revert_agent_history(): + logger.info("Reverting agent history") st.session_state.action_history.pop() st.session_state.action_info_history.pop() st.session_state.thought_history.pop() st.session_state.prompt_history.pop() + st.session_state.memory_history.pop() def revert_agent_state(): + logger.info("Reverting agent state") st.session_state.agent.obs_history.pop() st.session_state.agent.actions.pop() st.session_state.agent.thoughts.pop() @@ -140,16 +201,21 @@ def revert_agent_state(): def restore_env_history(step: int): - st.session_state.obs_history = st.session_state.obs_history[:step] - st.session_state.screenshot_history = st.session_state.screenshot_history[:step] - st.session_state.axtree_history = st.session_state.axtree_history[:step] + logger.info(f"Restoring env history to step {step}") + st.session_state.obs_history = copy.deepcopy(st.session_state.obs_history[:step]) + st.session_state.screenshot_history = copy.deepcopy(st.session_state.screenshot_history[:step]) + st.session_state.axtree_history = copy.deepcopy(st.session_state.axtree_history[:step]) def restore_agent_history(step: int): - st.session_state.action_history = st.session_state.action_history[:step] - st.session_state.action_info_history = st.session_state.action_info_history[:step] - st.session_state.thought_history = st.session_state.thought_history[:step] - st.session_state.prompt_history = st.session_state.prompt_history[:step] + logger.info(f"Restoring agent history to step {step}") + st.session_state.action_history = copy.deepcopy(st.session_state.action_history[:step]) + st.session_state.action_info_history = copy.deepcopy( + st.session_state.action_info_history[:step] + ) + st.session_state.thought_history = copy.deepcopy(st.session_state.thought_history[:step]) + st.session_state.prompt_history = copy.deepcopy(st.session_state.prompt_history[:step]) + st.session_state.memory_history = copy.deepcopy(st.session_state.memory_history[:step]) def get_prompt(info): @@ -177,38 +243,7 @@ def get_prompt(info): def setup_sidebar(): with st.sidebar: - st.markdown( - """ -# AgentLab Controller - -AgentLab Controller is a tool used to help control and debug an agent deployed in an environment. - -AgentLab Controller works by connecting a Streamlit UI that handles the agent to a FastAPI backend server that handles the environment. - ---- - -## Instructions - -1. ⚙️ Setup the task - - Select an agent, benchmark, task, and subtask you want to work on. - - Select "🔄" to reset the environment. This includes resetting the environment server. - - Select "▶️" to start the environment. This will start the environment by opening a browser in the background. This step might take some time - -2. 🎮 Control the environment - - Look at the goal set for the task, the thought of the model, and the action taken. - - If the action looks right, select the "▶️ Next Step" button to step the environment. - + The action will then be executed and the environment will be updated. - - If the action is wrong and you want to re-prompt, select the "🔄 Regenerate Action". - + You can also expand the "Prompt Modifier" menu to change the prompt used to generate the thoughts / actions. - - If you want to backtrack and undo the previous actions, select the "⬅️ Previous Step" button. - + Note: This is a slow process as we need to reset the environment server and perform the previous actions one by one. - -3. 🔎 Investigate the environment - - Look at the screenshot of the current environment state - - Verify that the action selected by the model matches the AxTree - - Ensure that the prompt is properly build. If there are issues with the prompt yielding the wrong action, modify them using the "Prompt Modifier" above. - """ - ) + st.markdown(SIDEBAR_TEXT) def set_session_state(): @@ -257,6 +292,8 @@ def set_session_state(): st.session_state.has_clicked_prev = False if "has_clicked_next" not in st.session_state: st.session_state.has_clicked_next = False + if "has_clicked_multiple_reprompt" not in st.session_state: + st.session_state.has_clicked_multiple_reprompt = False def select_agent(): @@ -294,39 +331,41 @@ def select_subtask(benchmark, task_str) -> str: def set_task_selector(): """Create task selector form. Allows the user to select the agent, benchmark, task, and subtask to run.""" - with st.form("Task Selector"): - col1, col2, col3, col4, col5, col6 = st.columns( - [2, 2, 4, 2, 1, 1], vertical_alignment="bottom" - ) - with col1: - selected_agent_args = select_agent() - with col2: - selected_benchmark_str = select_benchmark() - selected_benchmark = DEFAULT_BENCHMARKS[selected_benchmark_str]() - with col3: - selected_task_str = select_task(selected_benchmark) - with col4: - selected_subtask_str = select_subtask(selected_benchmark, selected_task_str) - with col5: - if st.form_submit_button("🔄", use_container_width=True): - clean_session() - with col6: - if st.form_submit_button("▶️", use_container_width=True): - - # saving configs related to agent and task - st.session_state.has_submitted_configs = True - st.session_state.agent_args = selected_agent_args - st.session_state.benchmark = selected_benchmark_str - st.session_state.task = selected_task_str - st.session_state.subtask = selected_subtask_str - - reset_env_history() - reset_agent_history() - - prepare_agent() - set_environment_info() - prepare_benchmark() - reset_environment() + with st.container(border=True): + st.markdown("##### ⚙️ Select") + with st.form("Task Selector"): + col1, col2, col3, col4, col5, col6 = st.columns( + [2, 2, 4, 2, 1, 1], vertical_alignment="bottom" + ) + with col1: + selected_agent_args = select_agent() + with col2: + selected_benchmark_str = select_benchmark() + selected_benchmark = DEFAULT_BENCHMARKS[selected_benchmark_str]() + with col3: + selected_task_str = select_task(selected_benchmark) + with col4: + selected_subtask_str = select_subtask(selected_benchmark, selected_task_str) + with col5: + if st.form_submit_button("🔄", use_container_width=True): + clean_session() + with col6: + if st.form_submit_button("▶️", use_container_width=True): + + # saving configs related to agent and task + st.session_state.has_submitted_configs = True + st.session_state.agent_args = selected_agent_args + st.session_state.benchmark = selected_benchmark_str + st.session_state.task = selected_task_str + st.session_state.subtask = selected_subtask_str + + reset_env_history() + reset_agent_history() + + prepare_agent() + set_environment_info() + prepare_benchmark() + reset_environment() def clean_session(): @@ -386,18 +425,28 @@ def reset_environment(): if st.session_state.agent.obs_preprocessor: obs = st.session_state.agent.obs_preprocessor(obs) step_env_history(obs) + st.session_state.action = None + st.session_state.action_info = None def reload_task(): logger.info("Reloading task...") start = datetime.now() resp = requests.post(f"{SERVER_URL}/reload_task") + end = datetime.now() + logger.info(f"Done request in {end - start}") if resp.status_code != 200 or resp.json().get(Constants.STATUS) != Constants.STATUS_SUCCESS: logger.error(resp.status_code) logger.error(resp.json()[Constants.STATUS]) logger.error(resp.json()[Constants.MESSAGE]) - end = datetime.now() - logger.info(f"Done in {end - start}") + response_json = resp.json() + response_json = deserialize_response(response_json) + obs = response_json[Constants.OBS] + if st.session_state.agent.obs_preprocessor: + obs = st.session_state.agent.obs_preprocessor(obs) + step_env_history(obs) + st.session_state.action = None + st.session_state.action_info = None def step_environment(action): @@ -405,26 +454,25 @@ def step_environment(action): start = datetime.now() payload = {"action": action} resp = requests.post(f"{SERVER_URL}/step", json=payload) + end = datetime.now() + logger.info(f"Done request in {end - start}") if resp.status_code != 200 or resp.json().get(Constants.STATUS) != Constants.STATUS_SUCCESS: logger.error(resp.status_code) logger.error(resp.json()[Constants.STATUS]) logger.error(resp.json()[Constants.MESSAGE]) response_json = resp.json() response_json = deserialize_response(response_json) + obs = response_json[Constants.OBS] if st.session_state.agent.obs_preprocessor: - response_json[Constants.OBS] = st.session_state.agent.obs_preprocessor( - response_json[Constants.OBS] - ) - step_env_history(response_json[Constants.OBS]) + obs = st.session_state.agent.obs_preprocessor(obs) + step_env_history(obs) st.session_state.action = None st.session_state.action_info = None - end = datetime.now() - logger.info(f"Done in {end - start}") def restore_environment(): reload_task() - for action in st.session_state.agent.actions: + for action in st.session_state.action_history[:-1]: step_environment(action) st.session_state.action = st.session_state.action_history[-1] st.session_state.action_info = st.session_state.action_info_history[-1] @@ -475,18 +523,32 @@ def set_agent_state_box(): with col2: with st.container(border=True, height=250): st.markdown("**Think**") + initial_think = copy.deepcopy(st.session_state.action_info.think) st.session_state.action_info.think = st.text_area( "Think", st.session_state.action_info.think, height=172, label_visibility="collapsed", ) + if st.session_state.action_info.think != initial_think: + # if thought has been updated, update thought history + st.session_state.thought_history[-1] = copy.deepcopy( + st.session_state.action_info.think + ) + st.session_state.agent.thoughts[-1] = copy.deepcopy( + st.session_state.action_info.think + ) with col3: with st.container(border=True, height=250): st.markdown("**Action**") + initial_action = copy.deepcopy(st.session_state.action) st.session_state.action = st.text_area( "Action", st.session_state.action, height=172, label_visibility="collapsed" ) + if st.session_state.action != initial_action: + # if action has been updated, update action history + st.session_state.action_history[-1] = copy.deepcopy(st.session_state.action) + st.session_state.agent.actions[-1] = copy.deepcopy(st.session_state.action) def set_prompt_modifier(): @@ -578,58 +640,71 @@ def set_prompt_modifier(): st.session_state.agent.flags.extra_instructions = extra_instructions -def set_go_back_to_step_k_section(): +def set_go_back_to_step_n_section(): with st.container(border=True): - st.markdown("**Go Back to Step K**") + st.markdown("**Go Back to Step N**") col1, col2 = st.columns([1, 1], vertical_alignment="bottom") - is_go_back_to_step_k_disabled = len(st.session_state.action_history) <= 1 + is_go_back_to_step_n_disabled = len(st.session_state.action_history) <= 1 with col1: step = st.number_input( "Step", value=1, min_value=1, max_value=len(st.session_state.action_history), - disabled=is_go_back_to_step_k_disabled, + disabled=is_go_back_to_step_n_disabled, ) with col2: if st.button( - "Go Back", - help="Go back to step K", + "⬅️ Go Back", + help="Go back to step N", use_container_width=True, - disabled=is_go_back_to_step_k_disabled, + disabled=is_go_back_to_step_n_disabled, ): + logger.info(f"Going back to step {step}") reset_agent_state() restore_agent_history(step=step) - restore_env_history(step=step) + reset_env_history() restore_environment() st.rerun() -def set_reprompt_k_times_section(): +def set_regenerate_action_n_times_section(): with st.container(border=True): - st.markdown("**Reprompt K Times**") + st.markdown("**Regenerate Action N Times**") col1, col2 = st.columns([1, 1], vertical_alignment="bottom") with col1: - k = st.number_input( - "Number of Generations", + n = st.number_input( + "Number of Actions to Generate", value=5, min_value=1, max_value=25, ) with col2: - has_clicked_reprompt = st.button( - "Reprompt", + st.session_state.has_clicked_multiple_reprompt = st.button( + "🔄 Regenerate", help="Reprompt the agent K times to get a distribution of actions to take", use_container_width=True, ) - if has_clicked_reprompt: + if st.session_state.has_clicked_multiple_reprompt: + logger.info(f"Regenerating action {n} times...") reprompt_actions = [] - with st.spinner(f"Reprompting {k} times"): - for i in range(k): - revert_agent_history() - revert_agent_state() - get_action() - reprompt_actions.append(st.session_state.action) + action_to_info_mapping = {} + action_to_memory_mapping = {} + progress_bar = st.progress(0, text=f"Regenerating action {n} times...") + for i in range(n): + progress_bar.progress((i + 1) / n, text=f"Regenerating action {i + 1} of {n}...") + revert_agent_history() + revert_agent_state() + get_action() + reprompt_actions.append(st.session_state.action) + action_to_info_mapping[st.session_state.action] = copy.deepcopy( + st.session_state.action_info + ) + action_to_memory_mapping[st.session_state.action] = copy.deepcopy( + st.session_state.agent.memories[-1] + ) + progress_bar.progress(1, text=f"Regenerating action {n} times...") + progress_bar.empty() # show all unique actions found in reprompt actions along with their probability unique_actions_counter = Counter(reprompt_actions) unique_actions = sorted( @@ -637,29 +712,48 @@ def set_reprompt_k_times_section(): ) st.markdown("**Unique Actions**") for action, count in unique_actions: - selected_action = st.button(f"`{action}` ({count / k * 100:.2f}%)") - if selected_action: - step_environment(action) + has_clicked_reprompted_action = st.button(f"`{action}` ({count / n * 100:.2f}%)") + if has_clicked_reprompted_action: + logger.info(f"Selected action: {action} -- stepping") + st + revert_agent_history() + revert_agent_state() + + # manually step agent state + st.session_state.agent.obs_history.append( + copy.deepcopy(st.session_state.last_obs) + ) + st.session_state.agent.actions.append(action) + st.session_state.agent.thoughts.append(action_to_info_mapping[action].think) + st.session_state.agent.memories.append(action_to_memory_mapping[action]) + + step_agent_history(action, action_to_info_mapping[action]) + # step_environment(action) + st.session_state.has_clicked_multiple_reprompt = False st.rerun() def set_act_k_times_section(): with st.container(border=True): - st.markdown("**Act K Times**") + st.markdown("**Go Forward N Steps**") col1, col2 = st.columns([1, 1], vertical_alignment="bottom") with col1: - k = st.number_input("Number of Steps", value=5, min_value=1, max_value=10) + n = st.number_input("Number of Steps", value=5, min_value=1, max_value=10) with col2: has_clicked_act = st.button( - "Act", - help="Let the agent autonomously perform actions for K steps", + "➡️ Go Forward", + help="Let the agent autonomously perform actions for N steps", use_container_width=True, ) if has_clicked_act: - with st.spinner(f"Acting {k} times"): - for _ in range(k): + logger.info(f"Going forward {n} steps...") + progress_bar = st.progress(0, text=f"Going forward {n} steps...") + for i in range(n): + if st.session_state.action is None: # so that we don't do it for first step get_action() - step_environment(st.session_state.action) + step_environment(st.session_state.action) + progress_bar.progress((i + 1) / n, text=f"Going forward {i + 1} of {n}...") + progress_bar.empty() st.rerun() @@ -667,9 +761,9 @@ def set_advanced_controller(): with st.expander("**Advanced**", expanded=False): col_go_back_to, col_reprompt_k, col_act_k = st.columns([1, 1, 1]) with col_go_back_to: - set_go_back_to_step_k_section() + set_go_back_to_step_n_section() with col_reprompt_k: - set_reprompt_k_times_section() + set_regenerate_action_n_times_section() with col_act_k: set_act_k_times_section() @@ -678,6 +772,7 @@ def set_previous_step_section(): prev_disabled = len(st.session_state.action_history) <= 1 if st.button("⬅️ Previous Step", disabled=prev_disabled, use_container_width=True): if not prev_disabled: + logger.info("Clicked previous step") st.session_state.action = ( None if len(st.session_state.action_history) == 0 @@ -685,13 +780,14 @@ def set_previous_step_section(): ) reset_agent_state() revert_agent_history() - revert_env_history() + reset_env_history() restore_environment() st.rerun() def set_regenerate_action_section(): if st.button("🔄 Regenerate Action", use_container_width=True): + logger.info("Clicked regenerate action") revert_agent_history() revert_agent_state() get_action() @@ -700,21 +796,24 @@ def set_regenerate_action_section(): def set_next_step_section(): if st.button("➡️ Next Step", use_container_width=True): + logger.info("Clicked next step") step_environment(st.session_state.action) st.rerun() def set_controller(): - set_agent_state_box() - set_prompt_modifier() - col_prev, col_redo, col_next = st.columns([1, 1, 1]) - with col_prev: - set_previous_step_section() - with col_redo: - set_regenerate_action_section() - with col_next: - set_next_step_section() - set_advanced_controller() + with st.container(border=True): + st.markdown("##### 🎮 Control") + set_agent_state_box() + set_prompt_modifier() + col_prev, col_redo, col_next = st.columns([1, 1, 1]) + with col_prev: + set_previous_step_section() + with col_redo: + set_regenerate_action_section() + with col_next: + set_next_step_section() + set_advanced_controller() def get_base64_serialized_image(img_arr): @@ -754,9 +853,10 @@ def set_previous_steps_tab(): for i in range(len(st.session_state.action_history) - 1): with st.expander(f"### Step {i + 1}", expanded=False): if st.button(f"Go back to step {i + 1}"): + logger.info(f"Go back to step {i + 1}") reset_agent_state() restore_agent_history(step=i + 1) - restore_env_history(step=i + 1) + reset_env_history() restore_environment() st.rerun() screenshot_tab, axtree_tab, prompt_tab = st.tabs(["Screenshot", "AxTree", "Prompt"]) @@ -818,25 +918,27 @@ def set_save_tab(): def set_info_tabs(): - # Display only if everything is now ready - if len(st.session_state.action_history) > 1: - screenshot_tab, axtree_tab, prompt_tab, previous_steps_tab, save_tab = st.tabs( - ["Screenshot", "AxTree", "Prompt", "Previous Steps", "Save"] - ) - else: - screenshot_tab, axtree_tab, prompt_tab = st.tabs(["Screenshot", "AxTree", "Prompt"]) + with st.container(border=True): + st.markdown("##### 🔎 Analyze") + # Display only if everything is now ready + if len(st.session_state.action_history) > 1: + screenshot_tab, axtree_tab, prompt_tab, previous_steps_tab, save_tab = st.tabs( + ["Screenshot", "AxTree", "Prompt", "Previous Steps", "Save"] + ) + else: + screenshot_tab, axtree_tab, prompt_tab = st.tabs(["Screenshot", "AxTree", "Prompt"]) - with screenshot_tab: - set_screenshot_tab() - with axtree_tab: - set_axtree_tab() - with prompt_tab: - set_prompt_tab() - if len(st.session_state.action_history) > 1: - with previous_steps_tab: - set_previous_steps_tab() - with save_tab: - set_save_tab() + with screenshot_tab: + set_screenshot_tab() + with axtree_tab: + set_axtree_tab() + with prompt_tab: + set_prompt_tab() + if len(st.session_state.action_history) > 1: + with previous_steps_tab: + set_previous_steps_tab() + with save_tab: + set_save_tab() def run_streamlit(): @@ -860,8 +962,8 @@ def run_streamlit(): if st.session_state.agent is not None: if st.session_state.action is None: get_action() - with st.container(border=True): - set_controller() + + set_controller() set_info_tabs() From 6f07f35193e0921c10dffeb2fbca7174792e1774 Mon Sep 17 00:00:00 2001 From: Patrice Bechard Date: Sat, 5 Jul 2025 11:31:10 -0400 Subject: [PATCH 19/24] update controller --- src/agentlab/analyze/agent_controller.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/agentlab/analyze/agent_controller.py b/src/agentlab/analyze/agent_controller.py index a2e90319..dd91e2f2 100644 --- a/src/agentlab/analyze/agent_controller.py +++ b/src/agentlab/analyze/agent_controller.py @@ -20,7 +20,10 @@ from transformers import AutoTokenizer # used to display prompt. simple chat template from apache 2.0 model -tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta") +# tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta") +tokenizer = AutoTokenizer.from_pretrained( + "/Users/patrice.bechard/.cache/huggingface/hub/models--HuggingFaceH4--zephyr-7b-beta/snapshots/892b3d7a7b1cf10c7a701c60881cd93df615734c" +) logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) @@ -803,7 +806,7 @@ def set_next_step_section(): def set_controller(): with st.container(border=True): - st.markdown("##### 🎮 Control") + st.markdown("##### 🕹️ Control") set_agent_state_box() set_prompt_modifier() col_prev, col_redo, col_next = st.columns([1, 1, 1]) @@ -946,12 +949,12 @@ def run_streamlit(): # config page st.set_page_config( page_title="AgentLab Controller", - page_icon="🎮", + page_icon="🕹️", layout="wide", initial_sidebar_state="collapsed", ) st.markdown( - '

🎮 AgentLab Controller 🎮

', unsafe_allow_html=True + '

🕹️ AgentLab Controller 🕹️

', unsafe_allow_html=True ) setup_sidebar() From 21ebdd776fad9d67d6ed43867d5fa8432e815ef9 Mon Sep 17 00:00:00 2001 From: Patrice Bechard Date: Mon, 14 Jul 2025 21:57:38 -0400 Subject: [PATCH 20/24] add ability to save with same format as agentlab-xray --- src/agentlab/analyze/agent_controller.py | 128 +++++++++++++++-------- src/agentlab/analyze/server.py | 13 +++ 2 files changed, 100 insertions(+), 41 deletions(-) diff --git a/src/agentlab/analyze/agent_controller.py b/src/agentlab/analyze/agent_controller.py index dd91e2f2..e82f1396 100644 --- a/src/agentlab/analyze/agent_controller.py +++ b/src/agentlab/analyze/agent_controller.py @@ -4,9 +4,11 @@ import json import logging import os +import pickle from collections import Counter from datetime import datetime from io import BytesIO +from pathlib import Path import numpy as np import PIL.Image @@ -14,6 +16,7 @@ import streamlit as st from agentlab.agents.generic_agent import __all__ as ALL_AGENTS from agentlab.experiments.exp_utils import RESULTS_DIR +from agentlab.experiments.loop import ExpArgs, StepInfo, save_package_versions from agentlab.llm.llm_utils import Discussion from bgym import DEFAULT_BENCHMARKS from dotenv import load_dotenv @@ -133,6 +136,12 @@ def reset_env_history(): st.session_state.screenshot_history = [] st.session_state.axtree_history = [] + # related to env info + st.session_state.reward_history = [] + st.session_state.terminated_history = [] + st.session_state.truncated_history = [] + st.session_state.env_info_history = [] + def reset_agent_history(): logger.info("Resetting agent history") @@ -150,13 +159,19 @@ def reset_agent_state(): st.session_state.agent.reset() -def step_env_history(obs): +def step_env_history(obs, response_json): logger.info("Stepping env history") st.session_state.last_obs = copy.deepcopy(obs) st.session_state.obs_history.append(obs) st.session_state.screenshot_history.append(obs[Constants.SCREENSHOT]) st.session_state.axtree_history.append(obs[Constants.AXTREE_TXT]) + # other relevant info found in response_json + st.session_state.reward_history.append(response_json["reward"]) + st.session_state.terminated_history.append(response_json["terminated"]) + st.session_state.truncated_history.append(response_json["truncated"]) + st.session_state.env_info_history.append(response_json["info"]) + def step_agent_history(action, action_info): logger.info("Stepping agent history") @@ -185,6 +200,12 @@ def revert_env_history(): st.session_state.screenshot_history.pop() st.session_state.axtree_history.pop() + # related to env info + st.session_state.reward_history.pop() + st.session_state.terminated_history.pop() + st.session_state.truncated_history.pop() + st.session_state.env_info_history.pop() + def revert_agent_history(): logger.info("Reverting agent history") @@ -209,6 +230,12 @@ def restore_env_history(step: int): st.session_state.screenshot_history = copy.deepcopy(st.session_state.screenshot_history[:step]) st.session_state.axtree_history = copy.deepcopy(st.session_state.axtree_history[:step]) + # related to env info + st.session_state.reward_history = copy.deepcopy(st.session_state.reward_history[:step]) + st.session_state.terminated_history = copy.deepcopy(st.session_state.terminated_history[:step]) + st.session_state.truncated_history = copy.deepcopy(st.session_state.truncated_history[:step]) + st.session_state.env_info_history = copy.deepcopy(st.session_state.env_info_history[:step]) + def restore_agent_history(step: int): logger.info(f"Restoring agent history to step {step}") @@ -262,6 +289,8 @@ def set_session_state(): st.session_state.task = None if "subtask" not in st.session_state: st.session_state.subtask = None + if "env_args" not in st.session_state: + st.session_state.env_args = None # current state if "agent" not in st.session_state: @@ -290,6 +319,14 @@ def set_session_state(): st.session_state.action_info_history = None if "obs_history" not in st.session_state: st.session_state.obs_history = None + if "reward_history" not in st.session_state: + st.session_state.reward_history = None + if "terminated_history" not in st.session_state: + st.session_state.terminated_history = None + if "truncated_history" not in st.session_state: + st.session_state.truncated_history = None + if "env_info_history" not in st.session_state: + st.session_state.env_info_history = None if "has_clicked_prev" not in st.session_state: st.session_state.has_clicked_prev = False @@ -362,6 +399,13 @@ def set_task_selector(): st.session_state.task = selected_task_str st.session_state.subtask = selected_subtask_str + st.session_state.env_args = [ + elem + for elem in selected_benchmark.env_args_list + if elem.task_name == selected_task_str + and str(elem.task_seed) == str(selected_subtask_str) + ][0] + reset_env_history() reset_agent_history() @@ -423,11 +467,12 @@ def reset_environment(): logger.error(resp.json()[Constants.STATUS]) logger.error(resp.json()[Constants.MESSAGE]) response_json = resp.json() + print(response_json.keys()) response_json = deserialize_response(response_json) obs = response_json[Constants.OBS] if st.session_state.agent.obs_preprocessor: obs = st.session_state.agent.obs_preprocessor(obs) - step_env_history(obs) + step_env_history(obs, response_json) st.session_state.action = None st.session_state.action_info = None @@ -447,7 +492,7 @@ def reload_task(): obs = response_json[Constants.OBS] if st.session_state.agent.obs_preprocessor: obs = st.session_state.agent.obs_preprocessor(obs) - step_env_history(obs) + step_env_history(obs, response_json) st.session_state.action = None st.session_state.action_info = None @@ -468,7 +513,7 @@ def step_environment(action): obs = response_json[Constants.OBS] if st.session_state.agent.obs_preprocessor: obs = st.session_state.agent.obs_preprocessor(obs) - step_env_history(obs) + step_env_history(obs, response_json) st.session_state.action = None st.session_state.action_info = None @@ -880,44 +925,45 @@ def set_save_tab(): save_dir = st.text_input("Save Directory", value="~/Downloads") save_dir = os.path.expanduser(save_dir) if st.button("Save Session State for Current Run"): - now_str = datetime.now().strftime("%Y_%m_%d__%H_%M_%S") - filename = f"agentlab_controller_state_{now_str}.json" - - # prepare payload for saving - payload = {} - payload["timestamp"] = now_str - payload["benchmark"] = st.session_state.benchmark - payload["task"] = st.session_state.task - payload["subtask"] = st.session_state.subtask - payload["agent_args"] = { - k: v for k, v in vars(st.session_state.agent_args).items() if is_json_serializable(v) - } - payload["agent_flags"] = { - k: v for k, v in vars(st.session_state.agent.flags).items() if is_json_serializable(v) - } - payload["agent_flags"]["obs"] = { - k: v - for k, v in vars(st.session_state.agent.flags.obs).items() - if is_json_serializable(v) - } - payload["agent_flags"]["action"] = { - k: v - for k, v in vars(st.session_state.agent.flags.action).items() - if is_json_serializable(v) - } - payload["goal"] = st.session_state.last_obs["goal"] - payload["steps"] = [] + # save everything from the session in a way that is consistent + # with how experiments are saved with AgentLab + + # dir name has this format: 2025-07-14_16-46-47_tooluse-gpt-4-1-on-workarena-l1-task-name-sort + exp_dir = ( + Path(save_dir) + / f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}_genericagent_{st.session_state.agent_args.agent_name}_on_{st.session_state.benchmark}_{st.session_state.env_args.task_name}_{st.session_state.env_args.task_name}_{st.session_state.env_args.task_seed}" + ) + exp_dir.mkdir(parents=True, exist_ok=True) + + # save package versions + save_package_versions(exp_dir) + + # create ExpArgs object + exp_args = ExpArgs( + agent_args=st.session_state.agent_args, env_args=st.session_state.env_args + ) + with open(exp_dir / "exp_args.pkl", "wb") as f: + pickle.dump(exp_args, f) + + # create StepInfo object for each step for i in range(len(st.session_state.action_history)): - step = {} - step["action"] = st.session_state.action_history[i] - step["thought"] = st.session_state.thought_history[i] - step["prompt"] = st.session_state.prompt_history[i] - step["screenshot"] = get_base64_serialized_image(st.session_state.screenshot_history[i]) - step["axtree"] = st.session_state.axtree_history[i] - payload["steps"].append(step) - - with open(os.path.join(save_dir, filename), "w") as f: - json.dump(payload, f) + step_info = StepInfo() + step_info.step = i + step_info.obs = st.session_state.obs_history[i] + step_info.reward = st.session_state.reward_history[i] + step_info.terminated = st.session_state.terminated_history[i] + step_info.truncated = st.session_state.truncated_history[i] + step_info.action = st.session_state.action_history[i] + step_info.agent_info = st.session_state.action_info_history[i] + step_info.make_stats() + # TODO: set profiling stats + step_info.task_info = st.session_state.env_info_history[i].get("task_info", None) + step_info.raw_reward = st.session_state.env_info_history[i].get( + "RAW_REWARD_GLOBAL", None + ) + step_info.save_step_info(exp_dir, save_screenshot=True, save_som=True) + + st.success(f"Saved session state at {exp_dir}") def set_info_tabs(): diff --git a/src/agentlab/analyze/server.py b/src/agentlab/analyze/server.py index df3bfbdd..4f82f4c4 100644 --- a/src/agentlab/analyze/server.py +++ b/src/agentlab/analyze/server.py @@ -234,6 +234,10 @@ def status(self) -> dict: { "status": "success", "message": "Environment status retrieved successfully.", + "obs": self.last_obs, + "reward": 0, + "terminated": False, + "truncated": False, "info_set": self.info_set, "env_created": self.env is not None, } @@ -318,6 +322,9 @@ def reload_task(self) -> dict: "status": "success", "message": "Task reloaded successfully.", "obs": self.last_obs, + "reward": 0, + "terminated": False, + "truncated": False, "info": self.last_info, } ) @@ -356,6 +363,9 @@ def reset(self) -> dict: "status": "success", "message": "Environment reset successfully", "obs": self.last_obs, + "reward": 0, + "terminated": False, + "truncated": False, "info": self.last_info, } ) @@ -413,6 +423,9 @@ def get_obs(self) -> dict: "status": "success", "message": "Observation retrieved successfully.", "obs": self.last_obs, + "reward": 0, + "terminated": False, + "truncated": False, "info": self.last_info, } ) From a7702dbdc592b043a9c0938979d4ede6794180e3 Mon Sep 17 00:00:00 2001 From: Patrice Bechard Date: Wed, 16 Jul 2025 13:44:32 -0400 Subject: [PATCH 21/24] support for ToolUseAgent in controller, enable loading of previous run --- .../agents/tool_use_agent/__init__.py | 9 + .../agents/tool_use_agent/tool_use_agent.py | 18 + src/agentlab/analyze/agent_controller.py | 449 +++++++++++++----- src/agentlab/llm/response_api.py | 56 ++- 4 files changed, 408 insertions(+), 124 deletions(-) diff --git a/src/agentlab/agents/tool_use_agent/__init__.py b/src/agentlab/agents/tool_use_agent/__init__.py index 935fea14..2f4a0139 100644 --- a/src/agentlab/agents/tool_use_agent/__init__.py +++ b/src/agentlab/agents/tool_use_agent/__init__.py @@ -4,3 +4,12 @@ # for backward compatibility of unpickling sys.modules[__name__ + ".multi_tool_agent"] = sys.modules[__name__] + +__all__ = [ + "GPT_4_1", + "AZURE_GPT_4_1", + "GPT_4_1_MINI", + "AZURE_GPT_4_1_MINI", + "OPENAI_CHATAPI_MODEL_CONFIG", + "CLAUDE_MODEL_CONFIG", +] diff --git a/src/agentlab/agents/tool_use_agent/tool_use_agent.py b/src/agentlab/agents/tool_use_agent/tool_use_agent.py index b7494693..3ccb0db4 100644 --- a/src/agentlab/agents/tool_use_agent/tool_use_agent.py +++ b/src/agentlab/agents/tool_use_agent/tool_use_agent.py @@ -508,6 +508,15 @@ def get_action(self, obs: Any) -> float: vision_support=True, ) +AZURE_GPT_4_1 = AzureOpenAIResponseModelArgs( + model_name="gpt-4.1", + max_total_tokens=200_000, + max_input_tokens=200_000, + max_new_tokens=2_000, + temperature=0.1, + vision_support=True, +) + GPT_4_1_MINI = OpenAIResponseModelArgs( model_name="gpt-4.1-mini", max_total_tokens=200_000, @@ -517,6 +526,15 @@ def get_action(self, obs: Any) -> float: vision_support=True, ) +AZURE_GPT_4_1_MINI = AzureOpenAIResponseModelArgs( + model_name="gpt-4.1-mini", + max_total_tokens=200_000, + max_input_tokens=200_000, + max_new_tokens=2_000, + temperature=0.1, + vision_support=True, +) + OPENAI_CHATAPI_MODEL_CONFIG = OpenAIChatModelArgs( model_name="gpt-4o-2024-08-06", max_total_tokens=200_000, diff --git a/src/agentlab/analyze/agent_controller.py b/src/agentlab/analyze/agent_controller.py index e82f1396..d0bd9ee0 100644 --- a/src/agentlab/analyze/agent_controller.py +++ b/src/agentlab/analyze/agent_controller.py @@ -1,5 +1,6 @@ import base64 import copy +import gzip import importlib import json import logging @@ -14,7 +15,14 @@ import PIL.Image import requests import streamlit as st -from agentlab.agents.generic_agent import __all__ as ALL_AGENTS +from agentlab.agents.generic_agent import __all__ as ALL_GENERIC_AGENTS +from agentlab.agents.generic_agent.generic_agent import GenericAgent +from agentlab.agents.tool_use_agent import __all__ as ALL_TOOL_USE_AGENTS +from agentlab.agents.tool_use_agent.tool_use_agent import ( + DEFAULT_PROMPT_CONFIG, + ToolUseAgent, + ToolUseAgentArgs, +) from agentlab.experiments.exp_utils import RESULTS_DIR from agentlab.experiments.loop import ExpArgs, StepInfo, save_package_versions from agentlab.llm.llm_utils import Discussion @@ -32,7 +40,6 @@ logger.setLevel(logging.INFO) load_dotenv() -DEFAULT_AGENT = "AGENT_AZURE_4o" DEFAULT_BENCHMARK = "workarena_l1" SERVER_URL = "http://127.0.0.1:8000" @@ -151,7 +158,8 @@ def reset_agent_history(): st.session_state.action_info_history = [] st.session_state.thought_history = [] st.session_state.prompt_history = [] - st.session_state.memory_history = [] + if isinstance(st.session_state.agent, GenericAgent): + st.session_state.memory_history = [] def reset_agent_state(): @@ -183,7 +191,8 @@ def step_agent_history(action, action_info): st.session_state.prompt_history.append(get_prompt(action_info)) # HACK: memory history can only be obtained via the agent - st.session_state.memory_history.append(st.session_state.agent.memories[-1]) + if isinstance(st.session_state.agent, GenericAgent): + st.session_state.memory_history.append(st.session_state.agent.memories[-1]) def set_agent_state(): @@ -191,7 +200,8 @@ def set_agent_state(): st.session_state.agent.obs_history = copy.deepcopy(st.session_state.obs_history) st.session_state.agent.actions = copy.deepcopy(st.session_state.action_history) st.session_state.agent.thoughts = copy.deepcopy(st.session_state.thought_history) - st.session_state.agent.memories = copy.deepcopy(st.session_state.memory_history) + if isinstance(st.session_state.agent, GenericAgent): + st.session_state.agent.memories = copy.deepcopy(st.session_state.memory_history) def revert_env_history(): @@ -213,7 +223,8 @@ def revert_agent_history(): st.session_state.action_info_history.pop() st.session_state.thought_history.pop() st.session_state.prompt_history.pop() - st.session_state.memory_history.pop() + if isinstance(st.session_state.agent, GenericAgent): + st.session_state.memory_history.pop() def revert_agent_state(): @@ -245,30 +256,39 @@ def restore_agent_history(step: int): ) st.session_state.thought_history = copy.deepcopy(st.session_state.thought_history[:step]) st.session_state.prompt_history = copy.deepcopy(st.session_state.prompt_history[:step]) - st.session_state.memory_history = copy.deepcopy(st.session_state.memory_history[:step]) + if isinstance(st.session_state.agent, GenericAgent): + st.session_state.memory_history = copy.deepcopy(st.session_state.memory_history[:step]) def get_prompt(info): - if info is not None and isinstance(info.chat_messages, Discussion): - chat_messages = info.chat_messages.messages - new_chat_messages = [] - for message in chat_messages: - if isinstance(message["content"], list): - # concatenate all text elements - new_chat_messages.append( - { - "role": message["role"], - "content": "\n\n".join( - [elem["text"] for elem in message["content"] if elem["type"] == "text"] - ), - } - ) - else: - new_chat_messages.append(message) - prompt = tokenizer.apply_chat_template( - new_chat_messages, add_special_tokens=True, tokenize=False - ) - return prompt + if info is not None: + if hasattr(info, "chat_messages") and isinstance(info.chat_messages, Discussion): + chat_messages = info.chat_messages.messages + new_chat_messages = [] + for message in chat_messages: + if isinstance(message["content"], list): + # concatenate all text elements + new_chat_messages.append( + { + "role": message["role"], + "content": "\n\n".join( + [ + elem["text"] + for elem in message["content"] + if elem["type"] == "text" + ] + ), + } + ) + else: + new_chat_messages.append(message) + prompt = tokenizer.apply_chat_template( + new_chat_messages, add_special_tokens=True, tokenize=False + ) + return prompt + else: + prompt = "Not implemented yet for Response API" + return prompt def setup_sidebar(): @@ -304,29 +324,38 @@ def set_session_state(): # track history if "prompt_history" not in st.session_state: - st.session_state.prompt_history = None + st.session_state.prompt_history = [] if "screenshot_history" not in st.session_state: - st.session_state.screenshot_history = None + st.session_state.screenshot_history = [] if "axtree_history" not in st.session_state: - st.session_state.axtree_history = None + st.session_state.axtree_history = [] if "thought_history" not in st.session_state: - st.session_state.thought_history = None + st.session_state.thought_history = [] if "memory_history" not in st.session_state: - st.session_state.memory_history = None + st.session_state.memory_history = [] if "action_history" not in st.session_state: - st.session_state.action_history = None + st.session_state.action_history = [] if "action_info_history" not in st.session_state: - st.session_state.action_info_history = None + st.session_state.action_info_history = [] if "obs_history" not in st.session_state: - st.session_state.obs_history = None + st.session_state.obs_history = [] if "reward_history" not in st.session_state: - st.session_state.reward_history = None + st.session_state.reward_history = [] if "terminated_history" not in st.session_state: - st.session_state.terminated_history = None + st.session_state.terminated_history = [] if "truncated_history" not in st.session_state: - st.session_state.truncated_history = None + st.session_state.truncated_history = [] if "env_info_history" not in st.session_state: - st.session_state.env_info_history = None + st.session_state.env_info_history = [] + + if "task_to_benchmark_mapping" not in st.session_state: + st.session_state.task_to_benchmark_mapping = {} + for benchmark in list(DEFAULT_BENCHMARKS.keys()): + all_tasks = set( + [elem.task_name for elem in DEFAULT_BENCHMARKS[benchmark]().env_args_list] + ) + for task in all_tasks: + st.session_state.task_to_benchmark_mapping[task] = benchmark if "has_clicked_prev" not in st.session_state: st.session_state.has_clicked_prev = False @@ -336,11 +365,36 @@ def set_session_state(): st.session_state.has_clicked_multiple_reprompt = False -def select_agent(): +def select_agent_type(): + """Dropdown to select an agent type.""" + agent_type = st.selectbox("Select Agent Type", ["GenericAgent", "ToolUseAgent"], index=0) + return agent_type + + +def select_agent(agent_type: str = "GenericAgent"): """Dropdown to select an agent.""" - agent_str = st.selectbox("Select Agent", ALL_AGENTS, index=ALL_AGENTS.index(DEFAULT_AGENT)) - agents_module = importlib.import_module("agentlab.agents.generic_agent") - agent = getattr(agents_module, agent_str) + if agent_type == "GenericAgent": + agent_choices = ALL_GENERIC_AGENTS + default_agent = "AGENT_AZURE_4o" + agent_str = st.selectbox( + "Select Agent", agent_choices, index=agent_choices.index(default_agent) + ) + agents_module = importlib.import_module("agentlab.agents.generic_agent") + agent = getattr(agents_module, agent_str) + elif agent_type == "ToolUseAgent": + agent_choices = ALL_TOOL_USE_AGENTS + default_agent = "AZURE_GPT_4_1" + agent_str = st.selectbox( + "Select Agent", agent_choices, index=agent_choices.index(default_agent) + ) + agents_module = importlib.import_module("agentlab.agents.tool_use_agent.tool_use_agent") + model_args = getattr(agents_module, agent_str) + agent = ToolUseAgentArgs( + model_args=model_args, + config=copy.deepcopy(DEFAULT_PROMPT_CONFIG), + ) + else: + st.error("Invalid agent type") return agent @@ -374,22 +428,24 @@ def set_task_selector(): with st.container(border=True): st.markdown("##### ⚙️ Select") with st.form("Task Selector"): - col1, col2, col3, col4, col5, col6 = st.columns( - [2, 2, 4, 2, 1, 1], vertical_alignment="bottom" + col1, col2, col3, col4, col5, col6, col7 = st.columns( + [2, 2, 2, 3, 1, 1, 1], vertical_alignment="bottom" ) with col1: - selected_agent_args = select_agent() + selected_agent_type = select_agent_type() with col2: + selected_agent_args = select_agent(selected_agent_type) + with col3: selected_benchmark_str = select_benchmark() selected_benchmark = DEFAULT_BENCHMARKS[selected_benchmark_str]() - with col3: - selected_task_str = select_task(selected_benchmark) with col4: - selected_subtask_str = select_subtask(selected_benchmark, selected_task_str) + selected_task_str = select_task(selected_benchmark) with col5: + selected_subtask_str = select_subtask(selected_benchmark, selected_task_str) + with col6: if st.form_submit_button("🔄", use_container_width=True): clean_session() - with col6: + with col7: if st.form_submit_button("▶️", use_container_width=True): # saving configs related to agent and task @@ -413,6 +469,93 @@ def set_task_selector(): set_environment_info() prepare_benchmark() reset_environment() + # alternatively, one can load a file from disk to load a previous session + with st.expander(label="Load a previous run", expanded=False): + with st.form("Load Previous Run"): + col1, col2 = st.columns( + (11, 1), + vertical_alignment="top", + border=False, + ) + with col1: + exp_files = st.file_uploader( + "Select all files from a previous run directory", + accept_multiple_files=True, + label_visibility="collapsed", + ) + with col2: + if st.form_submit_button( + "⬆️", + use_container_width=True, + ): + if exp_files: + with st.spinner("Loading session..."): + load_session(exp_files) + + +def load_session(exp_files): + logger.info(f"Loading session...") + start = datetime.now() + + # load env and agent args + exp_args_files = [file for file in exp_files if file.name == "exp_args.pkl"] + if len(exp_args_files) == 0: + st.error("No exp_args.pkl file found in the selected directory.") + return + exp_args = exp_args_files[0].getvalue() + exp_args = pickle.loads(exp_args) + st.session_state.agent_args = exp_args.agent_args + st.session_state.env_args = exp_args.env_args + st.session_state.benchmark = st.session_state.task_to_benchmark_mapping[ + exp_args.env_args.task_name + ] + st.session_state.task = exp_args.env_args.task_name + st.session_state.subtask = exp_args.env_args.task_seed + + # load state from step files + screenshot_file_names = [ + file.name for file in exp_files if file.name.startswith("screenshot_step_") + ] + step_files = [file for file in exp_files if file.name.startswith("step_")] + if len(step_files) == 0: + st.error("No step files found in the selected directory.") + return + # sort step files + step_files.sort(key=lambda x: int(x.name.split("_")[-1].split(".")[0])) + # only keep step files for which we have an associated `screenshot_step_n.png` + step_files = [ + file + for file in step_files + if f"screenshot_{file.name.split('.')[0]}.png" in screenshot_file_names + ] + for file in step_files: + with gzip.open(file, "rb") as f: + step_info = pickle.load(f) + st.session_state.action_history.append(step_info.action) + st.session_state.action_info_history.append(step_info.agent_info) + st.session_state.thought_history.append(step_info.agent_info.get("think", None)) + st.session_state.prompt_history.append(get_prompt(step_info.agent_info)) + if isinstance(st.session_state.agent, GenericAgent): + st.session_state.memory_history.append(step_info.agent_info.get("memory", None)) + st.session_state.obs_history.append(step_info.obs) + st.session_state.reward_history.append(step_info.reward) + st.session_state.terminated_history.append(step_info.terminated) + st.session_state.truncated_history.append(step_info.truncated) + st.session_state.env_info_history.append( + {"task_info": step_info.task_info, "RAW_REWARD_GLOBAL": step_info.raw_reward} + ) + st.session_state.last_obs = st.session_state.obs_history[-1] + + # set environment in right state + prepare_agent() + reset_env_history() + set_environment_info() + prepare_benchmark() + reset_environment() + restore_environment() + end = datetime.now() + logger.info(f"Done in {end - start}") + st.rerun() def clean_session(): @@ -430,6 +573,7 @@ def clean_session(): def prepare_agent(): st.session_state.agent_args.prepare() st.session_state.agent = st.session_state.agent_args.make_agent() + st.session_state.agent.set_task_name(st.session_state.task) def set_environment_info(): @@ -467,7 +611,6 @@ def reset_environment(): logger.error(resp.json()[Constants.STATUS]) logger.error(resp.json()[Constants.MESSAGE]) response_json = resp.json() - print(response_json.keys()) response_json = deserialize_response(response_json) obs = response_json[Constants.OBS] if st.session_state.agent.obs_preprocessor: @@ -601,91 +744,161 @@ def set_agent_state_box(): def set_prompt_modifier(): with st.expander("**Prompt Modifier**", expanded=False): - st.markdown("**Observation Flags**") - col1, col2, col3, col4, col5, col6 = st.columns([1, 1, 1, 1, 1, 1]) - with col1: - st.session_state.agent.flags.obs.use_html = st.checkbox( - "use_html", value=st.session_state.agent.flags.obs.use_html - ) - st.session_state.agent.flags.obs.use_action_history = st.checkbox( - "use_action_history", value=st.session_state.agent.flags.obs.use_action_history - ) - with col2: - st.session_state.agent.flags.obs.use_ax_tree = st.checkbox( - "use_ax_tree", value=st.session_state.agent.flags.obs.use_ax_tree - ) - st.session_state.agent.flags.obs.use_think_history = st.checkbox( - "use_think_history", value=st.session_state.agent.flags.obs.use_think_history + if isinstance(st.session_state.agent, GenericAgent): + st.markdown("**Observation Flags**") + col1, col2, col3, col4, col5, col6 = st.columns([1, 1, 1, 1, 1, 1]) + with col1: + st.session_state.agent.flags.obs.use_html = st.checkbox( + "use_html", value=st.session_state.agent.flags.obs.use_html + ) + st.session_state.agent.flags.obs.use_action_history = st.checkbox( + "use_action_history", value=st.session_state.agent.flags.obs.use_action_history + ) + with col2: + st.session_state.agent.flags.obs.use_ax_tree = st.checkbox( + "use_ax_tree", value=st.session_state.agent.flags.obs.use_ax_tree + ) + st.session_state.agent.flags.obs.use_think_history = st.checkbox( + "use_think_history", value=st.session_state.agent.flags.obs.use_think_history + ) + with col3: + st.session_state.agent.flags.obs.use_focused_element = st.checkbox( + "use_focused_element", + value=st.session_state.agent.flags.obs.use_focused_element, + ) + st.session_state.agent.flags.obs.use_diff = st.checkbox( + "use_diff", value=st.session_state.agent.flags.obs.use_diff + ) + with col4: + st.session_state.agent.flags.obs.use_error_logs = st.checkbox( + "use_error_logs", value=st.session_state.agent.flags.obs.use_error_logs + ) + st.session_state.agent.flags.obs.use_screenshot = st.checkbox( + "use_screenshot", value=st.session_state.agent.flags.obs.use_screenshot + ) + with col5: + st.session_state.agent.flags.obs.use_history = st.checkbox( + "use_history", value=st.session_state.agent.flags.obs.use_history + ) + st.session_state.agent.flags.obs.use_som = st.checkbox( + "use_som", value=st.session_state.agent.flags.obs.use_som + ) + with col6: + st.session_state.agent.flags.obs.use_past_error_logs = st.checkbox( + "use_past_error_logs", + value=st.session_state.agent.flags.obs.use_past_error_logs, + ) + st.session_state.agent.flags.obs.use_tabs = st.checkbox( + "use_tabs", value=st.session_state.agent.flags.obs.use_tabs + ) + st.markdown("**Other Flags**") + col1, col2, col3, col4, col5, col6 = st.columns([1, 1, 1, 1, 1, 1]) + with col1: + st.session_state.agent.flags.use_plan = st.checkbox( + "use_plan", value=st.session_state.agent.flags.use_plan + ) + st.session_state.agent.flags.use_hints = st.checkbox( + "use_hints", value=st.session_state.agent.flags.use_hints + ) + with col2: + st.session_state.agent.flags.use_criticise = st.checkbox( + "use_criticise", value=st.session_state.agent.flags.use_criticise + ) + st.session_state.agent.flags.be_cautious = st.checkbox( + "be_cautious", value=st.session_state.agent.flags.be_cautious + ) + with col3: + st.session_state.agent.flags.use_thinking = st.checkbox( + "use_thinking", value=st.session_state.agent.flags.use_thinking + ) + st.session_state.agent.flags.enable_chat = st.checkbox( + "enable_chat", value=st.session_state.agent.flags.enable_chat + ) + with col4: + st.session_state.agent.flags.use_memory = st.checkbox( + "use_memory", value=st.session_state.agent.flags.use_memory + ) + with col5: + st.session_state.agent.flags.use_abstract_example = st.checkbox( + "use_abstract_example", value=st.session_state.agent.flags.use_abstract_example + ) + with col6: + st.session_state.agent.flags.use_concrete_example = st.checkbox( + "use_concrete_example", value=st.session_state.agent.flags.use_concrete_example + ) + extra_instructions = st.text_area( + "extra_instructions", value=st.session_state.agent.flags.extra_instructions ) - with col3: - st.session_state.agent.flags.obs.use_focused_element = st.checkbox( - "use_focused_element", value=st.session_state.agent.flags.obs.use_focused_element + if extra_instructions == "": + extra_instructions = None + st.session_state.agent.flags.extra_instructions = extra_instructions + elif isinstance(st.session_state.agent, ToolUseAgent): + + st.session_state.agent.config.tag_screenshot = st.checkbox( + "Tag screenshot", value=st.session_state.agent.config.tag_screenshot ) - st.session_state.agent.flags.obs.use_diff = st.checkbox( - "use_diff", value=st.session_state.agent.flags.obs.use_diff + + # Goal flags + st.session_state.agent.config.goal.goal_as_system_msg = st.checkbox( + "Goal as system message", + value=st.session_state.agent.config.goal.goal_as_system_msg, ) - with col4: - st.session_state.agent.flags.obs.use_error_logs = st.checkbox( - "use_error_logs", value=st.session_state.agent.flags.obs.use_error_logs + + # Obs flags + st.session_state.agent.config.obs.use_last_error = st.checkbox( + "Use last error", value=st.session_state.agent.config.obs.use_last_error ) - st.session_state.agent.flags.obs.use_screenshot = st.checkbox( - "use_screenshot", value=st.session_state.agent.flags.obs.use_screenshot + st.session_state.agent.config.obs.use_screenshot = st.checkbox( + "Use screenshot", value=st.session_state.agent.config.obs.use_screenshot ) - with col5: - st.session_state.agent.flags.obs.use_history = st.checkbox( - "use_history", value=st.session_state.agent.flags.obs.use_history + st.session_state.agent.config.obs.use_axtree = st.checkbox( + "Use axtree", value=st.session_state.agent.config.obs.use_axtree ) - st.session_state.agent.flags.obs.use_som = st.checkbox( - "use_som", value=st.session_state.agent.flags.obs.use_som + st.session_state.agent.config.obs.use_dom = st.checkbox( + "Use dom", value=st.session_state.agent.config.obs.use_dom ) - with col6: - st.session_state.agent.flags.obs.use_past_error_logs = st.checkbox( - "use_past_error_logs", value=st.session_state.agent.flags.obs.use_past_error_logs + st.session_state.agent.config.obs.use_som = st.checkbox( + "Use som", value=st.session_state.agent.config.obs.use_som ) - st.session_state.agent.flags.obs.use_tabs = st.checkbox( - "use_tabs", value=st.session_state.agent.flags.obs.use_tabs + st.session_state.agent.config.obs.use_tabs = st.checkbox( + "Use tabs", value=st.session_state.agent.config.obs.use_tabs ) - st.markdown("**Other Flags**") - col1, col2, col3, col4, col5, col6 = st.columns([1, 1, 1, 1, 1, 1]) - with col1: - st.session_state.agent.flags.use_plan = st.checkbox( - "use_plan", value=st.session_state.agent.flags.use_plan + st.session_state.agent.config.obs.add_mouse_pointer = st.checkbox( + "Add mouse pointer", value=st.session_state.agent.config.obs.add_mouse_pointer ) - st.session_state.agent.flags.use_hints = st.checkbox( - "use_hints", value=st.session_state.agent.flags.use_hints + st.session_state.agent.config.obs.use_zoomed_webpage = st.checkbox( + "Use zoomed webpage", value=st.session_state.agent.config.obs.use_zoomed_webpage ) - with col2: - st.session_state.agent.flags.use_criticise = st.checkbox( - "use_criticise", value=st.session_state.agent.flags.use_criticise + + # Summarizer flags + st.session_state.agent.config.summarizer.do_summary = st.checkbox( + "Do summary", value=st.session_state.agent.config.summarizer.do_summary ) - st.session_state.agent.flags.be_cautious = st.checkbox( - "be_cautious", value=st.session_state.agent.flags.be_cautious + st.session_state.agent.config.summarizer.high_details = st.checkbox( + "Summarize with high details", + value=st.session_state.agent.config.summarizer.high_details, ) - with col3: - st.session_state.agent.flags.use_thinking = st.checkbox( - "use_thinking", value=st.session_state.agent.flags.use_thinking + + # General Hints flags + st.session_state.agent.config.general_hints.use_hints = st.checkbox( + "Use general hints", value=st.session_state.agent.config.general_hints.use_hints ) - st.session_state.agent.flags.enable_chat = st.checkbox( - "enable_chat", value=st.session_state.agent.flags.enable_chat + + # Task Hint flags + st.session_state.agent.config.task_hint.use_task_hint = st.checkbox( + "Use task hint", value=st.session_state.agent.config.task_hint.use_task_hint ) - with col4: - st.session_state.agent.flags.use_memory = st.checkbox( - "use_memory", value=st.session_state.agent.flags.use_memory + + # general + st.session_state.agent.config.keep_last_n_obs = st.number_input( + "Keep last n obs", value=st.session_state.agent.config.keep_last_n_obs ) - with col5: - st.session_state.agent.flags.use_abstract_example = st.checkbox( - "use_abstract_example", value=st.session_state.agent.flags.use_abstract_example + st.session_state.agent.config.multiaction = st.checkbox( + "Multiaction", value=st.session_state.agent.config.multiaction ) - with col6: - st.session_state.agent.flags.use_concrete_example = st.checkbox( - "use_concrete_example", value=st.session_state.agent.flags.use_concrete_example + st.session_state.agent.config.action_subsets = st.text_area( + "Action subsets", value=st.session_state.agent.config.action_subsets ) - extra_instructions = st.text_area( - "extra_instructions", value=st.session_state.agent.flags.extra_instructions - ) - if extra_instructions == "": - extra_instructions = None - st.session_state.agent.flags.extra_instructions = extra_instructions def set_go_back_to_step_n_section(): diff --git a/src/agentlab/llm/response_api.py b/src/agentlab/llm/response_api.py index 1bbeeebc..6121b392 100644 --- a/src/agentlab/llm/response_api.py +++ b/src/agentlab/llm/response_api.py @@ -4,20 +4,17 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field from typing import Any, Dict, List, Literal, Optional, Union +from urllib.parse import urljoin import openai +from agentlab.llm.llm_utils import image_to_png_base64_url from anthropic import Anthropic from anthropic.types import Completion from anthropic.types import Message as AnthrophicMessage from openai import OpenAI -from agentlab.llm.llm_utils import image_to_png_base64_url - from .base_api import BaseModelArgs -from .llm_utils import ( - call_anthropic_api_with_retries, - call_openai_api_with_retries, -) +from .llm_utils import call_anthropic_api_with_retries, call_openai_api_with_retries from .tracking import TrackAPIPricingMixin """This module contains utlity classes for building input messages and interacting with LLM APIs. @@ -588,6 +585,35 @@ def _extract_env_actions_from_text_response( pass +class AzureOpenAIResponseModel(OpenAIResponseModel): + def __init__( + self, + model_name: str, + api_key: Optional[str] = None, + temperature: float = 0.5, + max_tokens: int = 100, + extra_kwargs: Optional[Dict[str, Any]] = None, + **kwargs, + ): + api_key = os.getenv("AZURE_OPENAI_API_KEY") + self.tools = kwargs.pop("tools", None) + logging.info(f"Tools: {self.tools}") + super().__init__( + model_name=model_name, + api_key=api_key, + temperature=temperature, + max_tokens=max_tokens, + extra_kwargs=extra_kwargs, + **kwargs, + ) + # azure client takes extra kwargs + self.client = OpenAI( + api_key=api_key, + base_url=urljoin(os.getenv("AZURE_OPENAI_ENDPOINT"), "openai/v1"), + default_query={"api-version": "preview"}, + ) + + class OpenAIChatCompletionModel(BaseModelWithPricing): def __init__( self, @@ -920,6 +946,24 @@ def get_message_builder(self) -> MessageBuilder: return OpenAIResponseAPIMessageBuilder +@dataclass +class AzureOpenAIResponseModelArgs(OpenAIResponseModelArgs): + """Serializable object for instantiating a generic chat model with an Azure OpenAI + model.""" + + api = "openai" + + def make_model(self, extra_kwargs=None, **kwargs): + return AzureOpenAIResponseModel( + model_name=self.model_name, + temperature=self.temperature, + max_tokens=self.max_new_tokens, + extra_kwargs=extra_kwargs, + pricing_api="openai", + **kwargs, + ) + + @dataclass class ClaudeResponseModelArgs(BaseModelArgs): """Serializable object for instantiating a generic chat model with an OpenAI From 71999b40ff01d95209c24b755b3461eca5034f16 Mon Sep 17 00:00:00 2001 From: Patrice Bechard Date: Thu, 17 Jul 2025 09:53:19 -0400 Subject: [PATCH 22/24] update controller --- src/agentlab/analyze/agent_controller.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/agentlab/analyze/agent_controller.py b/src/agentlab/analyze/agent_controller.py index d0bd9ee0..6036b581 100644 --- a/src/agentlab/analyze/agent_controller.py +++ b/src/agentlab/analyze/agent_controller.py @@ -896,9 +896,9 @@ def set_prompt_modifier(): st.session_state.agent.config.multiaction = st.checkbox( "Multiaction", value=st.session_state.agent.config.multiaction ) - st.session_state.agent.config.action_subsets = st.text_area( - "Action subsets", value=st.session_state.agent.config.action_subsets - ) + # st.session_state.agent.config.action_subsets = st.text_area( + # "Action subsets", value=st.session_state.agent.config.action_subsets + # ) def set_go_back_to_step_n_section(): From c41b8efc22e9ab86c6ec5418a1e78895fe619749 Mon Sep 17 00:00:00 2001 From: Patrice Bechard Date: Mon, 21 Jul 2025 16:14:01 -0400 Subject: [PATCH 23/24] enable reprompt tool use agent from controller --- .../agents/tool_use_agent/hint_db.csv | 22 +++++++ .../agents/tool_use_agent/tool_use_agent.py | 27 ++++---- src/agentlab/analyze/agent_controller.py | 66 +++++++++++++++---- src/agentlab/llm/response_api.py | 38 +++++------ 4 files changed, 105 insertions(+), 48 deletions(-) diff --git a/src/agentlab/agents/tool_use_agent/hint_db.csv b/src/agentlab/agents/tool_use_agent/hint_db.csv index f402c24a..3d52959e 100644 --- a/src/agentlab/agents/tool_use_agent/hint_db.csv +++ b/src/agentlab/agents/tool_use_agent/hint_db.csv @@ -21,3 +21,25 @@ July 13,workarena.servicenow.create-hardware-asset,385,gpt-4.1,ToolUse-gpt-4.1,W July 13,workarena.servicenow.create-hardware-asset,385,gpt-4.1,ToolUse-gpt-4.1,WorkArena-L1,WorkArena-L1,allac,Filling form in WorkArena,"Before clicking submit, make sure that all fields are filled properly. Then click submit." July 13,workarena.servicenow.create-hardware-asset,385,gpt-4.1,ToolUse-gpt-4.1,WorkArena-L1,WorkArena-L1,allac,Filling form in WorkArena,Avoid back and forth from tabs to tabs to reduce the number of actions July 14,workarena.servicenow.create-hardware-asset,385,gpt-4.1,ToolUse-gpt-4.1,WorkArena-L1,WorkArena-L1,allac,Filling form in WorkArena,When you see auto-complete make sure to select an element from that list +July 16,workarena.servicenow.sort-asset-list,406,gpt-4-1,ToolUseAgent-gpt-4-1,workarena,workarena,patricebechard,Sorting lists in ServiceNow,"1. **Navigate to Your Table/List** + + * For example, go to **Incident > All** or any other table you want to view. + +2. **Sort by One or Multiple Columns** + + * `click` on the ""show / hide filter"" button (funnel icon) at the top left of the page to open the filter row. + * Repeat the following steps for each column you want to sort by: + * `click` on the ""Add Sort"" button to add a new sort filter. This will create a new ordering filter row with two comboboxes under the heading ""Order results by the following fields"". + * `fill` the first combobox with the appropriate field name you want to sort by. MAKE SURE to use the exact field name provided. + * `press` Enter after typing the field name. It is VERY IMPORTANT that you do this before doing anything else. DO NOT click on the run filter button before having confirmed your choice by explicitly pressing ENTER. + * `select_option` for the appropriate ordering between ascending (a to z) or descending (z to a) in the second combobox. + * Once all sort filters have been added, `click` the ""Run filter"" button to apply the sort. + +Notes: + * NEVER directly sort the columns using the table header. + * NEVER add columns via the Personalize List menu. + +3. **Resetting or Clearing Sorting** + + * To reset sorting, click another column, or click again to toggle. + * In the filter bar, you may see a ""Sorted by..."" indicator—clear or change it as needed." diff --git a/src/agentlab/agents/tool_use_agent/tool_use_agent.py b/src/agentlab/agents/tool_use_agent/tool_use_agent.py index 3ccb0db4..527882fc 100644 --- a/src/agentlab/agents/tool_use_agent/tool_use_agent.py +++ b/src/agentlab/agents/tool_use_agent/tool_use_agent.py @@ -8,15 +8,6 @@ import bgym import pandas as pd -from bgym import Benchmark as BgymBenchmark -from browsergym.core.observation import extract_screenshot -from browsergym.utils.obs import ( - flatten_axtree_to_str, - flatten_dom_to_str, - overlay_som, - prune_html, -) - from agentlab.agents.agent_args import AgentArgs from agentlab.benchmarks.abstract_env import AbstractBenchmark as AgentLabBenchmark from agentlab.benchmarks.osworld import OSWorldActionSet @@ -24,6 +15,7 @@ from agentlab.llm.llm_utils import image_to_png_base64_url from agentlab.llm.response_api import ( APIPayload, + AzureOpenAIResponseModelArgs, ClaudeResponseModelArgs, LLMOutput, MessageBuilder, @@ -33,6 +25,14 @@ ToolCalls, ) from agentlab.llm.tracking import cost_tracker_decorator +from bgym import Benchmark as BgymBenchmark +from browsergym.core.observation import extract_screenshot +from browsergym.utils.obs import ( + flatten_axtree_to_str, + flatten_dom_to_str, + overlay_som, + prune_html, +) @dataclass @@ -43,8 +43,8 @@ def _init(self): def make(self) -> "Block": """Returns a copy so the init can start adding some stuff to `self` without changing the - original datatclass that should only contain a config. - The aim is avoid having 2 calss definition for each block, e.g. Block and BlockArgs. + original dataclass that should only contain a config. + The aim is avoid having 2 class definitions for each block, e.g. Block and BlockArgs. Returns: Block: A copy of the current block instance with initialization applied. @@ -387,7 +387,6 @@ def __init__( self.config.action_subsets, multiaction=self.config.multiaction # type: ignore ) self.tools = self.action_set.to_tool_description(api=model_args.api) - self.call_ids = [] self.llm = model_args.make_model() @@ -595,8 +594,8 @@ def get_action(self, obs: Any) -> float: task_hint=TaskHint(use_task_hint=True), keep_last_n_obs=None, multiaction=True, # whether to use multi-action or not - # action_subsets=("bid",), - action_subsets=("coord"), + action_subsets=("bid",), + # action_subsets=("coord"), # action_subsets=("coord", "bid"), ) diff --git a/src/agentlab/analyze/agent_controller.py b/src/agentlab/analyze/agent_controller.py index 6036b581..0602d48b 100644 --- a/src/agentlab/analyze/agent_controller.py +++ b/src/agentlab/analyze/agent_controller.py @@ -25,7 +25,7 @@ ) from agentlab.experiments.exp_utils import RESULTS_DIR from agentlab.experiments.loop import ExpArgs, StepInfo, save_package_versions -from agentlab.llm.llm_utils import Discussion +from agentlab.llm.response_api import LLMOutput from bgym import DEFAULT_BENCHMARKS from dotenv import load_dotenv from transformers import AutoTokenizer @@ -188,7 +188,12 @@ def step_agent_history(action, action_info): st.session_state.action_history.append(action) st.session_state.action_info_history.append(action_info) st.session_state.thought_history.append(action_info.think) - st.session_state.prompt_history.append(get_prompt(action_info)) + if isinstance(st.session_state.agent, GenericAgent): + st.session_state.prompt_history.append(get_prompt(action_info)) + elif isinstance(st.session_state.agent, ToolUseAgent): + st.session_state.prompt_history.append( + "\n".join([elem.to_markdown() for elem in st.session_state.agent.discussion.flatten()]) + ) # HACK: memory history can only be obtained via the agent if isinstance(st.session_state.agent, GenericAgent): @@ -229,10 +234,31 @@ def revert_agent_history(): def revert_agent_state(): logger.info("Reverting agent state") - st.session_state.agent.obs_history.pop() - st.session_state.agent.actions.pop() - st.session_state.agent.thoughts.pop() - st.session_state.agent.memories.pop() + if isinstance(st.session_state.agent, GenericAgent): + st.session_state.agent.obs_history.pop() + st.session_state.agent.actions.pop() + st.session_state.agent.thoughts.pop() + st.session_state.agent.memories.pop() + elif isinstance(st.session_state.agent, ToolUseAgent): + num_groups = len(st.session_state.agent.discussion.groups) + if num_groups == 3: + # start from blank state + st.session_state.agent.discussion.groups = [] + st.session_state.agent.last_response = LLMOutput() + st.session_state.agent._responses = [] + elif num_groups > 3: + # get rid of the last group (last action), and remove everything from the other previous group except for the action + st.session_state.agent.discussion.groups.pop() + last_group = copy.deepcopy(st.session_state.agent.discussion.groups[-1]) + last_group.summary = None + last_group.messages = last_group.messages[:0] # remove everything from last group + st.session_state.agent.discussion.groups[-1] = last_group + st.session_state.agent._responses.pop() + st.session_state.agent.last_response = copy.deepcopy( + st.session_state.agent._responses[-1] + ) + else: + raise Exception("Invalid number of groups") def restore_env_history(step: int): @@ -534,9 +560,17 @@ def load_session(exp_files): st.session_state.action_history.append(step_info.action) st.session_state.action_info_history.append(step_info.agent_info) st.session_state.thought_history.append(step_info.agent_info.get("think", None)) - st.session_state.prompt_history.append(get_prompt(step_info.agent_info)) if isinstance(st.session_state.agent, GenericAgent): st.session_state.memory_history.append(step_info.agent_info.get("memory", None)) + st.session_state.prompt_history.append(get_prompt(step_info.agent_info)) + elif isinstance(st.session_state.agent, ToolUseAgent): + st.session_state.prompt_history.append( + "\n".join( + [elem.to_markdown() for elem in st.session_state.agent.discussion.flatten()] + ) + ) + else: + raise ValueError(f"Unknown agent type: {type(st.session_state.agent)}") st.session_state.obs_history.append(step_info.obs) st.session_state.reward_history.append(step_info.reward) st.session_state.terminated_history.append(step_info.terminated) @@ -573,7 +607,8 @@ def clean_session(): def prepare_agent(): st.session_state.agent_args.prepare() st.session_state.agent = st.session_state.agent_args.make_agent() - st.session_state.agent.set_task_name(st.session_state.task) + if isinstance(st.session_state.agent, ToolUseAgent): + st.session_state.agent.set_task_name(st.session_state.task) def set_environment_info(): @@ -863,9 +898,9 @@ def set_prompt_modifier(): st.session_state.agent.config.obs.use_tabs = st.checkbox( "Use tabs", value=st.session_state.agent.config.obs.use_tabs ) - st.session_state.agent.config.obs.add_mouse_pointer = st.checkbox( - "Add mouse pointer", value=st.session_state.agent.config.obs.add_mouse_pointer - ) + # st.session_state.agent.config.obs.add_mouse_pointer = st.checkbox( + # "Add mouse pointer", value=st.session_state.agent.config.obs.add_mouse_pointer + # ) st.session_state.agent.config.obs.use_zoomed_webpage = st.checkbox( "Use zoomed webpage", value=st.session_state.agent.config.obs.use_zoomed_webpage ) @@ -1107,7 +1142,14 @@ def set_axtree_tab(): def set_prompt_tab(): - st.code(st.session_state.prompt_history[-1], language=None, wrap_lines=True) + if isinstance(st.session_state.agent, GenericAgent): + st.code(st.session_state.prompt_history[-1], language=None, wrap_lines=True) + elif isinstance(st.session_state.agent, ToolUseAgent): + st.markdown(st.session_state.prompt_history[-1]) + + st.markdown(f"## Last summary:\n{st.session_state.agent.discussion.get_last_summary()}") + else: + raise ValueError(f"Unknown agent type: {type(st.session_state.agent)}") def set_previous_steps_tab(): diff --git a/src/agentlab/llm/response_api.py b/src/agentlab/llm/response_api.py index 6121b392..0d998d7c 100644 --- a/src/agentlab/llm/response_api.py +++ b/src/agentlab/llm/response_api.py @@ -589,29 +589,26 @@ class AzureOpenAIResponseModel(OpenAIResponseModel): def __init__( self, model_name: str, + base_url: Optional[str] = None, api_key: Optional[str] = None, - temperature: float = 0.5, - max_tokens: int = 100, - extra_kwargs: Optional[Dict[str, Any]] = None, - **kwargs, + temperature: float | None = None, + max_tokens: int | None = 100, ): api_key = os.getenv("AZURE_OPENAI_API_KEY") - self.tools = kwargs.pop("tools", None) - logging.info(f"Tools: {self.tools}") - super().__init__( - model_name=model_name, - api_key=api_key, - temperature=temperature, - max_tokens=max_tokens, - extra_kwargs=extra_kwargs, - **kwargs, - ) - # azure client takes extra kwargs - self.client = OpenAI( - api_key=api_key, - base_url=urljoin(os.getenv("AZURE_OPENAI_ENDPOINT"), "openai/v1"), - default_query={"api-version": "preview"}, + base_url = urljoin(os.getenv("AZURE_OPENAI_ENDPOINT"), "openai/v1") + self.action_space_as_tools = True # this should be a config + super().__init__( # This is passed to BaseModel + model_name=model_name, api_key=api_key, temperature=temperature, max_tokens=max_tokens ) + client_args = {} + if base_url is not None: + client_args["base_url"] = base_url + if api_key is not None: + client_args["api_key"] = api_key + client_args["default_query"] = {"api-version": "preview"} + self.client = OpenAI(**client_args) + # Init pricing tracker after super() so that all attributes have been set. + self.init_pricing_tracker(pricing_api="openai") # Use the PricingMixin class OpenAIChatCompletionModel(BaseModelWithPricing): @@ -958,9 +955,6 @@ def make_model(self, extra_kwargs=None, **kwargs): model_name=self.model_name, temperature=self.temperature, max_tokens=self.max_new_tokens, - extra_kwargs=extra_kwargs, - pricing_api="openai", - **kwargs, ) From ecff4d83b577725ce7648e1744a6d0cbe9030bca Mon Sep 17 00:00:00 2001 From: Patrice Bechard Date: Tue, 22 Jul 2025 11:33:11 -0400 Subject: [PATCH 24/24] updates to agent controller --- src/agentlab/agents/tool_use_agent/hint_db.csv | 6 ++++-- src/agentlab/agents/tool_use_agent/tool_use_agent.py | 2 +- src/agentlab/analyze/agent_controller.py | 11 +++++++---- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/src/agentlab/agents/tool_use_agent/hint_db.csv b/src/agentlab/agents/tool_use_agent/hint_db.csv index 3d52959e..76ee969d 100644 --- a/src/agentlab/agents/tool_use_agent/hint_db.csv +++ b/src/agentlab/agents/tool_use_agent/hint_db.csv @@ -28,16 +28,18 @@ July 16,workarena.servicenow.sort-asset-list,406,gpt-4-1,ToolUseAgent-gpt-4-1,wo 2. **Sort by One or Multiple Columns** * `click` on the ""show / hide filter"" button (funnel icon) at the top left of the page to open the filter row. - * Repeat the following steps for each column you want to sort by: + * Repeat the following steps for each column you want to sort by in this exact order: * `click` on the ""Add Sort"" button to add a new sort filter. This will create a new ordering filter row with two comboboxes under the heading ""Order results by the following fields"". * `fill` the first combobox with the appropriate field name you want to sort by. MAKE SURE to use the exact field name provided. - * `press` Enter after typing the field name. It is VERY IMPORTANT that you do this before doing anything else. DO NOT click on the run filter button before having confirmed your choice by explicitly pressing ENTER. + * `press` Enter after typing the field name to close the dropdown. It is VERY IMPORTANT that you do this before doing anything else otherwise the field will not be selected and the task will not be successful. DO NOT click on the run filter button before having confirmed your choice by explicitly pressing ENTER. * `select_option` for the appropriate ordering between ascending (a to z) or descending (z to a) in the second combobox. * Once all sort filters have been added, `click` the ""Run filter"" button to apply the sort. Notes: * NEVER directly sort the columns using the table header. * NEVER add columns via the Personalize List menu. + * ALWAYS sort the table using the EXACT NAMES of the provided fields. DO NOT use different but similar field names. For example, if the field you're asked to sort by is ""Opened by"", DO NOT filter by ""Created by"" even if they sound similar, but instead ALWAYS use the exact ""Opened by"" wording. + * Some columns might not appear by default in the visible view of the table. This does not mean they do not exist. ALWAYS use the EXACT names provided to sort by otherwise the task will not be successful. 3. **Resetting or Clearing Sorting** diff --git a/src/agentlab/agents/tool_use_agent/tool_use_agent.py b/src/agentlab/agents/tool_use_agent/tool_use_agent.py index 527882fc..d50a27e8 100644 --- a/src/agentlab/agents/tool_use_agent/tool_use_agent.py +++ b/src/agentlab/agents/tool_use_agent/tool_use_agent.py @@ -593,7 +593,7 @@ def get_action(self, obs: Any) -> float: general_hints=GeneralHints(use_hints=False), task_hint=TaskHint(use_task_hint=True), keep_last_n_obs=None, - multiaction=True, # whether to use multi-action or not + multiaction=False, # whether to use multi-action or not action_subsets=("bid",), # action_subsets=("coord"), # action_subsets=("coord", "bid"), diff --git a/src/agentlab/analyze/agent_controller.py b/src/agentlab/analyze/agent_controller.py index 0602d48b..8179a04b 100644 --- a/src/agentlab/analyze/agent_controller.py +++ b/src/agentlab/analyze/agent_controller.py @@ -164,7 +164,12 @@ def reset_agent_history(): def reset_agent_state(): logger.info("Resetting agent state") - st.session_state.agent.reset() + if isinstance(st.session_state.agent, GenericAgent): + st.session_state.agent.reset() + else: + st.session_state.agent.discussion.groups = [] + st.session_state.agent.last_response = LLMOutput() + st.session_state.agent._responses = [] def step_env_history(obs, response_json): @@ -243,9 +248,7 @@ def revert_agent_state(): num_groups = len(st.session_state.agent.discussion.groups) if num_groups == 3: # start from blank state - st.session_state.agent.discussion.groups = [] - st.session_state.agent.last_response = LLMOutput() - st.session_state.agent._responses = [] + reset_agent_state() elif num_groups > 3: # get rid of the last group (last action), and remove everything from the other previous group except for the action st.session_state.agent.discussion.groups.pop()