22import copy
33import importlib
44import logging
5+ from datetime import datetime
56from io import BytesIO
6- import requests
7+
78import numpy as np
89import PIL .Image
10+ import requests
911import streamlit as st
1012from agentlab .agents .generic_agent import __all__ as ALL_AGENTS
1113from agentlab .experiments .exp_utils import RESULTS_DIR
14+ from agentlab .llm .llm_utils import Discussion
1215from bgym import DEFAULT_BENCHMARKS
1316from dotenv import load_dotenv
14- from agentlab .llm .llm_utils import Discussion
1517from transformers import AutoTokenizer
16- from datetime import datetime
1718
1819# used to display prompt. simple chat template from apache 2.0 model
1920tokenizer = AutoTokenizer .from_pretrained ("HuggingFaceH4/zephyr-7b-beta" )
@@ -46,7 +47,9 @@ def deserialize_response(response_json):
4647 if "screenshot" in response_json ["obs" ]:
4748 screenshot_data = response_json ["obs" ]["screenshot" ]
4849 # convert base64 to numpy array
49- screenshot = np .frombuffer (base64 .b64decode (screenshot_data ["data" ]), dtype = np .dtype (screenshot_data ["dtype" ]))
50+ screenshot = np .frombuffer (
51+ base64 .b64decode (screenshot_data ["data" ]), dtype = np .dtype (screenshot_data ["dtype" ])
52+ )
5053 screenshot = screenshot .reshape (screenshot_data ["shape" ])
5154 response_json ["obs" ]["screenshot" ] = screenshot
5255 return response_json
@@ -132,7 +135,9 @@ def select_agent():
def select_benchmark() -> str:
    """Render the benchmark dropdown and return the selected entry.

    The options are the keys of ``DEFAULT_BENCHMARKS``; the entry equal to
    ``DEFAULT_BENCHMARK`` is pre-selected.
    """
    benchmark_names = [*DEFAULT_BENCHMARKS.keys()]
    # Position of the default entry; list.index raises ValueError if it is absent.
    default_position = benchmark_names.index(DEFAULT_BENCHMARK)
    return st.selectbox("Select Benchmark", benchmark_names, index=default_position)
137142
138143
@@ -145,15 +150,19 @@ def select_task(benchmark):
145150
def select_subtask(benchmark, task_str) -> str:
    """Dropdown to select a subtask (task seed) based on the task name.

    Args:
        benchmark: object exposing ``env_args_list``; each element is read for
            its ``task_name`` and ``task_seed`` attributes.
        task_str: task name whose seeds are offered in the dropdown.

    Returns:
        Whatever ``st.selectbox`` returns for the chosen seed string.
    """

    def _seed_order(seed_str):
        # Integer seeds are ordered numerically ("2" before "10"); anything
        # that is not an integer literal keeps plain string order, after them.
        try:
            return (0, int(seed_str), seed_str)
        except ValueError:
            return (1, 0, seed_str)

    all_subtasks = sorted(
        (str(elem.task_seed) for elem in benchmark.env_args_list if elem.task_name == task_str),
        key=_seed_order,
    )
    subtask_str = st.selectbox("Select Subtask", all_subtasks)
    return subtask_str
151158
152159
153160def set_task_selector ():
154161 """Create task selector form. Allows the user to select the agent, benchmark, task, and subtask to run."""
155162 with st .form ("Task Selector" ):
156- col1 , col2 , col3 , col4 , col5 , col6 = st .columns ([2 , 2 , 4 , 2 , 1 , 1 ], vertical_alignment = "bottom" )
163+ col1 , col2 , col3 , col4 , col5 , col6 = st .columns (
164+ [2 , 2 , 4 , 2 , 1 , 1 ], vertical_alignment = "bottom"
165+ )
157166 with col1 :
158167 selected_agent_args = select_agent ()
159168 with col2 :
@@ -339,38 +348,54 @@ def set_agent_state_box():
339348 with col1 :
340349 with st .container (border = True , height = 250 ):
341350 st .markdown ("**Goal**" )
342- st .code (st .session_state .agent .obs_history [- 1 ]["goal" ], wrap_lines = True , language = None , height = 175 )
351+ st .code (
352+ st .session_state .agent .obs_history [- 1 ]["goal" ],
353+ wrap_lines = True ,
354+ language = None ,
355+ height = 175 ,
356+ )
343357 with col2 :
344358 with st .container (border = True , height = 250 ):
345359 st .markdown ("**Think**" )
346360 st .session_state .action_info .think = st .text_area (
347- "Think" , st .session_state .action_info .think , height = 172 , label_visibility = "collapsed"
361+ "Think" ,
362+ st .session_state .action_info .think ,
363+ height = 172 ,
364+ label_visibility = "collapsed" ,
348365 )
349366 with col3 :
350367 with st .container (border = True , height = 250 ):
351368 st .markdown ("**Action**" )
352- st .session_state .action = st .text_area ("Action" , st .session_state .action , height = 172 , label_visibility = "collapsed" )
369+ st .session_state .action = st .text_area (
370+ "Action" , st .session_state .action , height = 172 , label_visibility = "collapsed"
371+ )
353372
354373
355374def set_prompt_modifier ():
356375 with st .expander ("**Prompt Modifier**" , expanded = False ):
357376 st .markdown ("**Observation Flags**" )
358377 col1 , col2 , col3 , col4 , col5 , col6 = st .columns ([1 , 1 , 1 , 1 , 1 , 1 ])
359378 with col1 :
360- st .session_state .agent .flags .obs .use_html = st .checkbox ("use_html" , value = st .session_state .agent .flags .obs .use_html )
379+ st .session_state .agent .flags .obs .use_html = st .checkbox (
380+ "use_html" , value = st .session_state .agent .flags .obs .use_html
381+ )
361382 st .session_state .agent .flags .obs .use_action_history = st .checkbox (
362383 "use_action_history" , value = st .session_state .agent .flags .obs .use_action_history
363384 )
364385 with col2 :
365- st .session_state .agent .flags .obs .use_ax_tree = st .checkbox ("use_ax_tree" , value = st .session_state .agent .flags .obs .use_ax_tree )
386+ st .session_state .agent .flags .obs .use_ax_tree = st .checkbox (
387+ "use_ax_tree" , value = st .session_state .agent .flags .obs .use_ax_tree
388+ )
366389 st .session_state .agent .flags .obs .use_think_history = st .checkbox (
367390 "use_think_history" , value = st .session_state .agent .flags .obs .use_think_history
368391 )
369392 with col3 :
370393 st .session_state .agent .flags .obs .use_focused_element = st .checkbox (
371394 "use_focused_element" , value = st .session_state .agent .flags .obs .use_focused_element
372395 )
373- st .session_state .agent .flags .obs .use_diff = st .checkbox ("use_diff" , value = st .session_state .agent .flags .obs .use_diff )
396+ st .session_state .agent .flags .obs .use_diff = st .checkbox (
397+ "use_diff" , value = st .session_state .agent .flags .obs .use_diff
398+ )
374399 with col4 :
375400 st .session_state .agent .flags .obs .use_error_logs = st .checkbox (
376401 "use_error_logs" , value = st .session_state .agent .flags .obs .use_error_logs
@@ -379,26 +404,46 @@ def set_prompt_modifier():
379404 "use_screenshot" , value = st .session_state .agent .flags .obs .use_screenshot
380405 )
381406 with col5 :
382- st .session_state .agent .flags .obs .use_history = st .checkbox ("use_history" , value = st .session_state .agent .flags .obs .use_history )
383- st .session_state .agent .flags .obs .use_som = st .checkbox ("use_som" , value = st .session_state .agent .flags .obs .use_som )
407+ st .session_state .agent .flags .obs .use_history = st .checkbox (
408+ "use_history" , value = st .session_state .agent .flags .obs .use_history
409+ )
410+ st .session_state .agent .flags .obs .use_som = st .checkbox (
411+ "use_som" , value = st .session_state .agent .flags .obs .use_som
412+ )
384413 with col6 :
385414 st .session_state .agent .flags .obs .use_past_error_logs = st .checkbox (
386415 "use_past_error_logs" , value = st .session_state .agent .flags .obs .use_past_error_logs
387416 )
388- st .session_state .agent .flags .obs .use_tabs = st .checkbox ("use_tabs" , value = st .session_state .agent .flags .obs .use_tabs )
417+ st .session_state .agent .flags .obs .use_tabs = st .checkbox (
418+ "use_tabs" , value = st .session_state .agent .flags .obs .use_tabs
419+ )
389420 st .markdown ("**Other Flags**" )
390421 col1 , col2 , col3 , col4 , col5 , col6 = st .columns ([1 , 1 , 1 , 1 , 1 , 1 ])
391422 with col1 :
392- st .session_state .agent .flags .use_plan = st .checkbox ("use_plan" , value = st .session_state .agent .flags .use_plan )
393- st .session_state .agent .flags .use_hints = st .checkbox ("use_hints" , value = st .session_state .agent .flags .use_hints )
423+ st .session_state .agent .flags .use_plan = st .checkbox (
424+ "use_plan" , value = st .session_state .agent .flags .use_plan
425+ )
426+ st .session_state .agent .flags .use_hints = st .checkbox (
427+ "use_hints" , value = st .session_state .agent .flags .use_hints
428+ )
394429 with col2 :
395- st .session_state .agent .flags .use_criticise = st .checkbox ("use_criticise" , value = st .session_state .agent .flags .use_criticise )
396- st .session_state .agent .flags .be_cautious = st .checkbox ("be_cautious" , value = st .session_state .agent .flags .be_cautious )
430+ st .session_state .agent .flags .use_criticise = st .checkbox (
431+ "use_criticise" , value = st .session_state .agent .flags .use_criticise
432+ )
433+ st .session_state .agent .flags .be_cautious = st .checkbox (
434+ "be_cautious" , value = st .session_state .agent .flags .be_cautious
435+ )
397436 with col3 :
398- st .session_state .agent .flags .use_thinking = st .checkbox ("use_thinking" , value = st .session_state .agent .flags .use_thinking )
399- st .session_state .agent .flags .enable_chat = st .checkbox ("enable_chat" , value = st .session_state .agent .flags .enable_chat )
437+ st .session_state .agent .flags .use_thinking = st .checkbox (
438+ "use_thinking" , value = st .session_state .agent .flags .use_thinking
439+ )
440+ st .session_state .agent .flags .enable_chat = st .checkbox (
441+ "enable_chat" , value = st .session_state .agent .flags .enable_chat
442+ )
400443 with col4 :
401- st .session_state .agent .flags .use_memory = st .checkbox ("use_memory" , value = st .session_state .agent .flags .use_memory )
444+ st .session_state .agent .flags .use_memory = st .checkbox (
445+ "use_memory" , value = st .session_state .agent .flags .use_memory
446+ )
402447 with col5 :
403448 st .session_state .agent .flags .use_abstract_example = st .checkbox (
404449 "use_abstract_example" , value = st .session_state .agent .flags .use_abstract_example
@@ -407,7 +452,9 @@ def set_prompt_modifier():
407452 st .session_state .agent .flags .use_concrete_example = st .checkbox (
408453 "use_concrete_example" , value = st .session_state .agent .flags .use_concrete_example
409454 )
410- extra_instructions = st .text_area ("extra_instructions" , value = st .session_state .agent .flags .extra_instructions )
455+ extra_instructions = st .text_area (
456+ "extra_instructions" , value = st .session_state .agent .flags .extra_instructions
457+ )
411458 if extra_instructions == "" :
412459 extra_instructions = None
413460 st .session_state .agent .flags .extra_instructions = extra_instructions
@@ -429,7 +476,11 @@ def set_controller():
429476 if st .button ("⬅️ Previous Step" , disabled = prev_disabled , use_container_width = True ):
430477 if not prev_disabled :
431478 st .session_state .actions_history .pop ()
432- st .session_state .action = None if len (st .session_state .actions_history ) == 0 else st .session_state .actions_history [- 1 ]
479+ st .session_state .action = (
480+ None
481+ if len (st .session_state .actions_history ) == 0
482+ else st .session_state .actions_history [- 1 ]
483+ )
433484 undo_last_agent_step ()
434485 undo_last_agent_step ()
435486 restore_environment ()
@@ -471,18 +522,31 @@ def set_axtree_tab():
471522
472523
def set_prompt_tab():
    """Show the current prompt, rendered through the chat template, as a code block.

    Nothing is drawn unless ``st.session_state.action_info`` is set and its
    ``chat_messages`` attribute is a ``Discussion``.
    """
    action_info = st.session_state.action_info
    if action_info is None or not isinstance(action_info.chat_messages, Discussion):
        return

    def _flatten(message):
        # Messages whose content is a list are reduced to their text parts,
        # joined by blank lines; other messages pass through untouched.
        if not isinstance(message["content"], list):
            return message
        text_parts = [part["text"] for part in message["content"] if part["type"] == "text"]
        return {"role": message["role"], "content": "\n\n".join(text_parts)}

    flattened = [_flatten(message) for message in action_info.chat_messages.messages]
    rendered = tokenizer.apply_chat_template(flattened, add_special_tokens=True, tokenize=False)
    st.code(rendered, wrap_lines=True, language="markdown")
486550
487551
488552def set_info_tabs ():
@@ -500,8 +564,15 @@ def set_info_tabs():
500564def run_streamlit ():
501565
502566 # config page
503- st .set_page_config (page_title = "AgentLab Controller" , page_icon = "🎮" , layout = "wide" , initial_sidebar_state = "collapsed" )
504- st .markdown ('<h1 style="text-align: center;">🎮 AgentLab Controller 🎮</h1>' , unsafe_allow_html = True )
567+ st .set_page_config (
568+ page_title = "AgentLab Controller" ,
569+ page_icon = "🎮" ,
570+ layout = "wide" ,
571+ initial_sidebar_state = "collapsed" ,
572+ )
573+ st .markdown (
574+ '<h1 style="text-align: center;">🎮 AgentLab Controller 🎮</h1>' , unsafe_allow_html = True
575+ )
505576
506577 setup_sidebar ()
507578
0 commit comments