Skip to content

Commit f2f3d20

Browse files
format with black line length 100
1 parent 8f1113d commit f2f3d20

File tree

2 files changed

+112
-33
lines changed

2 files changed

+112
-33
lines changed

src/agentlab/analyze/agent_controller.py

Lines changed: 101 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,19 @@
22
import copy
33
import importlib
44
import logging
5+
from datetime import datetime
56
from io import BytesIO
6-
import requests
7+
78
import numpy as np
89
import PIL.Image
10+
import requests
911
import streamlit as st
1012
from agentlab.agents.generic_agent import __all__ as ALL_AGENTS
1113
from agentlab.experiments.exp_utils import RESULTS_DIR
14+
from agentlab.llm.llm_utils import Discussion
1215
from bgym import DEFAULT_BENCHMARKS
1316
from dotenv import load_dotenv
14-
from agentlab.llm.llm_utils import Discussion
1517
from transformers import AutoTokenizer
16-
from datetime import datetime
1718

1819
# used to display prompt. simple chat template from apache 2.0 model
1920
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
@@ -46,7 +47,9 @@ def deserialize_response(response_json):
4647
if "screenshot" in response_json["obs"]:
4748
screenshot_data = response_json["obs"]["screenshot"]
4849
# convert base64 to numpy array
49-
screenshot = np.frombuffer(base64.b64decode(screenshot_data["data"]), dtype=np.dtype(screenshot_data["dtype"]))
50+
screenshot = np.frombuffer(
51+
base64.b64decode(screenshot_data["data"]), dtype=np.dtype(screenshot_data["dtype"])
52+
)
5053
screenshot = screenshot.reshape(screenshot_data["shape"])
5154
response_json["obs"]["screenshot"] = screenshot
5255
return response_json
@@ -132,7 +135,9 @@ def select_agent():
132135
def select_benchmark() -> str:
133136
"""Dropdown to select a benchmark."""
134137
all_benchmarks = list(DEFAULT_BENCHMARKS.keys())
135-
benchmark_str = st.selectbox("Select Benchmark", all_benchmarks, index=all_benchmarks.index(DEFAULT_BENCHMARK))
138+
benchmark_str = st.selectbox(
139+
"Select Benchmark", all_benchmarks, index=all_benchmarks.index(DEFAULT_BENCHMARK)
140+
)
136141
return benchmark_str
137142

138143

@@ -145,15 +150,19 @@ def select_task(benchmark):
145150

146151
def select_subtask(benchmark, task_str) -> str:
147152
"""Dropdown to select a subtask based on the task name."""
148-
all_subtasks = sorted([str(elem.task_seed) for elem in benchmark.env_args_list if elem.task_name == task_str])
153+
all_subtasks = sorted(
154+
[str(elem.task_seed) for elem in benchmark.env_args_list if elem.task_name == task_str]
155+
)
149156
subtask_str = st.selectbox("Select Subtask", all_subtasks)
150157
return subtask_str
151158

152159

153160
def set_task_selector():
154161
"""Create task selector form. Allows the user to select the agent, benchmark, task, and subtask to run."""
155162
with st.form("Task Selector"):
156-
col1, col2, col3, col4, col5, col6 = st.columns([2, 2, 4, 2, 1, 1], vertical_alignment="bottom")
163+
col1, col2, col3, col4, col5, col6 = st.columns(
164+
[2, 2, 4, 2, 1, 1], vertical_alignment="bottom"
165+
)
157166
with col1:
158167
selected_agent_args = select_agent()
159168
with col2:
@@ -339,38 +348,54 @@ def set_agent_state_box():
339348
with col1:
340349
with st.container(border=True, height=250):
341350
st.markdown("**Goal**")
342-
st.code(st.session_state.agent.obs_history[-1]["goal"], wrap_lines=True, language=None, height=175)
351+
st.code(
352+
st.session_state.agent.obs_history[-1]["goal"],
353+
wrap_lines=True,
354+
language=None,
355+
height=175,
356+
)
343357
with col2:
344358
with st.container(border=True, height=250):
345359
st.markdown("**Think**")
346360
st.session_state.action_info.think = st.text_area(
347-
"Think", st.session_state.action_info.think, height=172, label_visibility="collapsed"
361+
"Think",
362+
st.session_state.action_info.think,
363+
height=172,
364+
label_visibility="collapsed",
348365
)
349366
with col3:
350367
with st.container(border=True, height=250):
351368
st.markdown("**Action**")
352-
st.session_state.action = st.text_area("Action", st.session_state.action, height=172, label_visibility="collapsed")
369+
st.session_state.action = st.text_area(
370+
"Action", st.session_state.action, height=172, label_visibility="collapsed"
371+
)
353372

354373

355374
def set_prompt_modifier():
356375
with st.expander("**Prompt Modifier**", expanded=False):
357376
st.markdown("**Observation Flags**")
358377
col1, col2, col3, col4, col5, col6 = st.columns([1, 1, 1, 1, 1, 1])
359378
with col1:
360-
st.session_state.agent.flags.obs.use_html = st.checkbox("use_html", value=st.session_state.agent.flags.obs.use_html)
379+
st.session_state.agent.flags.obs.use_html = st.checkbox(
380+
"use_html", value=st.session_state.agent.flags.obs.use_html
381+
)
361382
st.session_state.agent.flags.obs.use_action_history = st.checkbox(
362383
"use_action_history", value=st.session_state.agent.flags.obs.use_action_history
363384
)
364385
with col2:
365-
st.session_state.agent.flags.obs.use_ax_tree = st.checkbox("use_ax_tree", value=st.session_state.agent.flags.obs.use_ax_tree)
386+
st.session_state.agent.flags.obs.use_ax_tree = st.checkbox(
387+
"use_ax_tree", value=st.session_state.agent.flags.obs.use_ax_tree
388+
)
366389
st.session_state.agent.flags.obs.use_think_history = st.checkbox(
367390
"use_think_history", value=st.session_state.agent.flags.obs.use_think_history
368391
)
369392
with col3:
370393
st.session_state.agent.flags.obs.use_focused_element = st.checkbox(
371394
"use_focused_element", value=st.session_state.agent.flags.obs.use_focused_element
372395
)
373-
st.session_state.agent.flags.obs.use_diff = st.checkbox("use_diff", value=st.session_state.agent.flags.obs.use_diff)
396+
st.session_state.agent.flags.obs.use_diff = st.checkbox(
397+
"use_diff", value=st.session_state.agent.flags.obs.use_diff
398+
)
374399
with col4:
375400
st.session_state.agent.flags.obs.use_error_logs = st.checkbox(
376401
"use_error_logs", value=st.session_state.agent.flags.obs.use_error_logs
@@ -379,26 +404,46 @@ def set_prompt_modifier():
379404
"use_screenshot", value=st.session_state.agent.flags.obs.use_screenshot
380405
)
381406
with col5:
382-
st.session_state.agent.flags.obs.use_history = st.checkbox("use_history", value=st.session_state.agent.flags.obs.use_history)
383-
st.session_state.agent.flags.obs.use_som = st.checkbox("use_som", value=st.session_state.agent.flags.obs.use_som)
407+
st.session_state.agent.flags.obs.use_history = st.checkbox(
408+
"use_history", value=st.session_state.agent.flags.obs.use_history
409+
)
410+
st.session_state.agent.flags.obs.use_som = st.checkbox(
411+
"use_som", value=st.session_state.agent.flags.obs.use_som
412+
)
384413
with col6:
385414
st.session_state.agent.flags.obs.use_past_error_logs = st.checkbox(
386415
"use_past_error_logs", value=st.session_state.agent.flags.obs.use_past_error_logs
387416
)
388-
st.session_state.agent.flags.obs.use_tabs = st.checkbox("use_tabs", value=st.session_state.agent.flags.obs.use_tabs)
417+
st.session_state.agent.flags.obs.use_tabs = st.checkbox(
418+
"use_tabs", value=st.session_state.agent.flags.obs.use_tabs
419+
)
389420
st.markdown("**Other Flags**")
390421
col1, col2, col3, col4, col5, col6 = st.columns([1, 1, 1, 1, 1, 1])
391422
with col1:
392-
st.session_state.agent.flags.use_plan = st.checkbox("use_plan", value=st.session_state.agent.flags.use_plan)
393-
st.session_state.agent.flags.use_hints = st.checkbox("use_hints", value=st.session_state.agent.flags.use_hints)
423+
st.session_state.agent.flags.use_plan = st.checkbox(
424+
"use_plan", value=st.session_state.agent.flags.use_plan
425+
)
426+
st.session_state.agent.flags.use_hints = st.checkbox(
427+
"use_hints", value=st.session_state.agent.flags.use_hints
428+
)
394429
with col2:
395-
st.session_state.agent.flags.use_criticise = st.checkbox("use_criticise", value=st.session_state.agent.flags.use_criticise)
396-
st.session_state.agent.flags.be_cautious = st.checkbox("be_cautious", value=st.session_state.agent.flags.be_cautious)
430+
st.session_state.agent.flags.use_criticise = st.checkbox(
431+
"use_criticise", value=st.session_state.agent.flags.use_criticise
432+
)
433+
st.session_state.agent.flags.be_cautious = st.checkbox(
434+
"be_cautious", value=st.session_state.agent.flags.be_cautious
435+
)
397436
with col3:
398-
st.session_state.agent.flags.use_thinking = st.checkbox("use_thinking", value=st.session_state.agent.flags.use_thinking)
399-
st.session_state.agent.flags.enable_chat = st.checkbox("enable_chat", value=st.session_state.agent.flags.enable_chat)
437+
st.session_state.agent.flags.use_thinking = st.checkbox(
438+
"use_thinking", value=st.session_state.agent.flags.use_thinking
439+
)
440+
st.session_state.agent.flags.enable_chat = st.checkbox(
441+
"enable_chat", value=st.session_state.agent.flags.enable_chat
442+
)
400443
with col4:
401-
st.session_state.agent.flags.use_memory = st.checkbox("use_memory", value=st.session_state.agent.flags.use_memory)
444+
st.session_state.agent.flags.use_memory = st.checkbox(
445+
"use_memory", value=st.session_state.agent.flags.use_memory
446+
)
402447
with col5:
403448
st.session_state.agent.flags.use_abstract_example = st.checkbox(
404449
"use_abstract_example", value=st.session_state.agent.flags.use_abstract_example
@@ -407,7 +452,9 @@ def set_prompt_modifier():
407452
st.session_state.agent.flags.use_concrete_example = st.checkbox(
408453
"use_concrete_example", value=st.session_state.agent.flags.use_concrete_example
409454
)
410-
extra_instructions = st.text_area("extra_instructions", value=st.session_state.agent.flags.extra_instructions)
455+
extra_instructions = st.text_area(
456+
"extra_instructions", value=st.session_state.agent.flags.extra_instructions
457+
)
411458
if extra_instructions == "":
412459
extra_instructions = None
413460
st.session_state.agent.flags.extra_instructions = extra_instructions
@@ -429,7 +476,11 @@ def set_controller():
429476
if st.button("⬅️ Previous Step", disabled=prev_disabled, use_container_width=True):
430477
if not prev_disabled:
431478
st.session_state.actions_history.pop()
432-
st.session_state.action = None if len(st.session_state.actions_history) == 0 else st.session_state.actions_history[-1]
479+
st.session_state.action = (
480+
None
481+
if len(st.session_state.actions_history) == 0
482+
else st.session_state.actions_history[-1]
483+
)
433484
undo_last_agent_step()
434485
undo_last_agent_step()
435486
restore_environment()
@@ -471,18 +522,31 @@ def set_axtree_tab():
471522

472523

473524
def set_prompt_tab():
474-
if st.session_state.action_info is not None and isinstance(st.session_state.action_info.chat_messages, Discussion):
525+
if st.session_state.action_info is not None and isinstance(
526+
st.session_state.action_info.chat_messages, Discussion
527+
):
475528
chat_messages = st.session_state.action_info.chat_messages.messages
476529
new_chat_messages = []
477530
for message in chat_messages:
478531
if isinstance(message["content"], list):
479532
# concatenate all text elements
480533
new_chat_messages.append(
481-
{"role": message["role"], "content": "\n\n".join([elem["text"] for elem in message["content"] if elem["type"] == "text"])}
534+
{
535+
"role": message["role"],
536+
"content": "\n\n".join(
537+
[elem["text"] for elem in message["content"] if elem["type"] == "text"]
538+
),
539+
}
482540
)
483541
else:
484542
new_chat_messages.append(message)
485-
st.code(tokenizer.apply_chat_template(new_chat_messages, add_special_tokens=True, tokenize=False), wrap_lines=True, language="markdown")
543+
st.code(
544+
tokenizer.apply_chat_template(
545+
new_chat_messages, add_special_tokens=True, tokenize=False
546+
),
547+
wrap_lines=True,
548+
language="markdown",
549+
)
486550

487551

488552
def set_info_tabs():
@@ -500,8 +564,15 @@ def set_info_tabs():
500564
def run_streamlit():
501565

502566
# config page
503-
st.set_page_config(page_title="AgentLab Controller", page_icon="🎮", layout="wide", initial_sidebar_state="collapsed")
504-
st.markdown('<h1 style="text-align: center;">🎮 AgentLab Controller 🎮</h1>', unsafe_allow_html=True)
567+
st.set_page_config(
568+
page_title="AgentLab Controller",
569+
page_icon="🎮",
570+
layout="wide",
571+
initial_sidebar_state="collapsed",
572+
)
573+
st.markdown(
574+
'<h1 style="text-align: center;">🎮 AgentLab Controller 🎮</h1>', unsafe_allow_html=True
575+
)
505576

506577
setup_sidebar()
507578

src/agentlab/analyze/server.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,11 @@ def make_json_safe(obj: Any) -> Any:
7272
"""
7373
if isinstance(obj, np.ndarray):
7474
# convert to base64
75-
return {"data": base64.b64encode(obj.tobytes()).decode("utf-8"), "shape": obj.shape, "dtype": str(obj.dtype)}
75+
return {
76+
"data": base64.b64encode(obj.tobytes()).decode("utf-8"),
77+
"shape": obj.shape,
78+
"dtype": str(obj.dtype),
79+
}
7680
elif isinstance(obj, dict):
7781
return {k: make_json_safe(v) for k, v in obj.items()}
7882
elif isinstance(obj, (list, tuple)):
@@ -258,7 +262,9 @@ def prepare_benchmark(self) -> dict:
258262
# prepare backends
259263
benchmark = DEFAULT_BENCHMARKS[self.benchmark_name]()
260264
benchmark.env_args_list = [
261-
elem for elem in benchmark.env_args_list if elem.task_name == self.task_name and str(elem.task_seed) == str(self.seed)
265+
elem
266+
for elem in benchmark.env_args_list
267+
if elem.task_name == self.task_name and str(elem.task_seed) == str(self.seed)
262268
]
263269
benchmark.prepare_backends()
264270

@@ -300,7 +306,9 @@ def reload_task(self) -> dict:
300306
# NOTE: this is not guaranteed to result in the exact same state, but we find that it works most of the time, is much
301307
# faster than resetting the whole environment, and ensures the seed of the environment remains the same
302308
self.env.unwrapped.page.goto(self.start_url, wait_until="load")
303-
self.env.unwrapped.page.evaluate("window.localStorage.clear(); window.sessionStorage.clear();")
309+
self.env.unwrapped.page.evaluate(
310+
"window.localStorage.clear(); window.sessionStorage.clear();"
311+
)
304312
obs = self.env.unwrapped._get_obs()
305313

306314
self.last_obs = copy.deepcopy(obs)

0 commit comments

Comments
 (0)