44import json
55import logging
66import os
7+ import pickle
78from collections import Counter
89from datetime import datetime
910from io import BytesIO
11+ from pathlib import Path
1012
1113import numpy as np
1214import PIL .Image
1315import requests
1416import streamlit as st
1517from agentlab .agents .generic_agent import __all__ as ALL_AGENTS
1618from agentlab .experiments .exp_utils import RESULTS_DIR
19+ from agentlab .experiments .loop import ExpArgs , StepInfo , save_package_versions
1720from agentlab .llm .llm_utils import Discussion
1821from bgym import DEFAULT_BENCHMARKS
1922from dotenv import load_dotenv
@@ -133,6 +136,12 @@ def reset_env_history():
133136 st .session_state .screenshot_history = []
134137 st .session_state .axtree_history = []
135138
139+ # related to env info
140+ st .session_state .reward_history = []
141+ st .session_state .terminated_history = []
142+ st .session_state .truncated_history = []
143+ st .session_state .env_info_history = []
144+
136145
137146def reset_agent_history ():
138147 logger .info ("Resetting agent history" )
@@ -150,13 +159,19 @@ def reset_agent_state():
150159 st .session_state .agent .reset ()
151160
152161
153- def step_env_history (obs ):
162+ def step_env_history (obs , response_json ):
154163 logger .info ("Stepping env history" )
155164 st .session_state .last_obs = copy .deepcopy (obs )
156165 st .session_state .obs_history .append (obs )
157166 st .session_state .screenshot_history .append (obs [Constants .SCREENSHOT ])
158167 st .session_state .axtree_history .append (obs [Constants .AXTREE_TXT ])
159168
169+ # other relevant info found in response_json
170+ st .session_state .reward_history .append (response_json ["reward" ])
171+ st .session_state .terminated_history .append (response_json ["terminated" ])
172+ st .session_state .truncated_history .append (response_json ["truncated" ])
173+ st .session_state .env_info_history .append (response_json ["info" ])
174+
160175
161176def step_agent_history (action , action_info ):
162177 logger .info ("Stepping agent history" )
@@ -185,6 +200,12 @@ def revert_env_history():
185200 st .session_state .screenshot_history .pop ()
186201 st .session_state .axtree_history .pop ()
187202
203+ # related to env info
204+ st .session_state .reward_history .pop ()
205+ st .session_state .terminated_history .pop ()
206+ st .session_state .truncated_history .pop ()
207+ st .session_state .env_info_history .pop ()
208+
188209
189210def revert_agent_history ():
190211 logger .info ("Reverting agent history" )
@@ -209,6 +230,12 @@ def restore_env_history(step: int):
209230 st .session_state .screenshot_history = copy .deepcopy (st .session_state .screenshot_history [:step ])
210231 st .session_state .axtree_history = copy .deepcopy (st .session_state .axtree_history [:step ])
211232
233+ # related to env info
234+ st .session_state .reward_history = copy .deepcopy (st .session_state .reward_history [:step ])
235+ st .session_state .terminated_history = copy .deepcopy (st .session_state .terminated_history [:step ])
236+ st .session_state .truncated_history = copy .deepcopy (st .session_state .truncated_history [:step ])
237+ st .session_state .env_info_history = copy .deepcopy (st .session_state .env_info_history [:step ])
238+
212239
213240def restore_agent_history (step : int ):
214241 logger .info (f"Restoring agent history to step { step } " )
@@ -262,6 +289,8 @@ def set_session_state():
262289 st .session_state .task = None
263290 if "subtask" not in st .session_state :
264291 st .session_state .subtask = None
292+ if "env_args" not in st .session_state :
293+ st .session_state .env_args = None
265294
266295 # current state
267296 if "agent" not in st .session_state :
@@ -290,6 +319,14 @@ def set_session_state():
290319 st .session_state .action_info_history = None
291320 if "obs_history" not in st .session_state :
292321 st .session_state .obs_history = None
322+ if "reward_history" not in st .session_state :
323+ st .session_state .reward_history = None
324+ if "terminated_history" not in st .session_state :
325+ st .session_state .terminated_history = None
326+ if "truncated_history" not in st .session_state :
327+ st .session_state .truncated_history = None
328+ if "env_info_history" not in st .session_state :
329+ st .session_state .env_info_history = None
293330
294331 if "has_clicked_prev" not in st .session_state :
295332 st .session_state .has_clicked_prev = False
@@ -362,6 +399,13 @@ def set_task_selector():
362399 st .session_state .task = selected_task_str
363400 st .session_state .subtask = selected_subtask_str
364401
402+ st .session_state .env_args = [
403+ elem
404+ for elem in selected_benchmark .env_args_list
405+ if elem .task_name == selected_task_str
406+ and str (elem .task_seed ) == str (selected_subtask_str )
407+ ][0 ]
408+
365409 reset_env_history ()
366410 reset_agent_history ()
367411
@@ -423,11 +467,12 @@ def reset_environment():
423467 logger .error (resp .json ()[Constants .STATUS ])
424468 logger .error (resp .json ()[Constants .MESSAGE ])
425469 response_json = resp .json ()
470+ print (response_json .keys ())
426471 response_json = deserialize_response (response_json )
427472 obs = response_json [Constants .OBS ]
428473 if st .session_state .agent .obs_preprocessor :
429474 obs = st .session_state .agent .obs_preprocessor (obs )
430- step_env_history (obs )
475+ step_env_history (obs , response_json )
431476 st .session_state .action = None
432477 st .session_state .action_info = None
433478
@@ -447,7 +492,7 @@ def reload_task():
447492 obs = response_json [Constants .OBS ]
448493 if st .session_state .agent .obs_preprocessor :
449494 obs = st .session_state .agent .obs_preprocessor (obs )
450- step_env_history (obs )
495+ step_env_history (obs , response_json )
451496 st .session_state .action = None
452497 st .session_state .action_info = None
453498
@@ -468,7 +513,7 @@ def step_environment(action):
468513 obs = response_json [Constants .OBS ]
469514 if st .session_state .agent .obs_preprocessor :
470515 obs = st .session_state .agent .obs_preprocessor (obs )
471- step_env_history (obs )
516+ step_env_history (obs , response_json )
472517 st .session_state .action = None
473518 st .session_state .action_info = None
474519
@@ -880,44 +925,45 @@ def set_save_tab():
880925 save_dir = st .text_input ("Save Directory" , value = "~/Downloads" )
881926 save_dir = os .path .expanduser (save_dir )
882927 if st .button ("Save Session State for Current Run" ):
883- now_str = datetime .now ().strftime ("%Y_%m_%d__%H_%M_%S" )
884- filename = f"agentlab_controller_state_{ now_str } .json"
885-
886- # prepare payload for saving
887- payload = {}
888- payload ["timestamp" ] = now_str
889- payload ["benchmark" ] = st .session_state .benchmark
890- payload ["task" ] = st .session_state .task
891- payload ["subtask" ] = st .session_state .subtask
892- payload ["agent_args" ] = {
893- k : v for k , v in vars (st .session_state .agent_args ).items () if is_json_serializable (v )
894- }
895- payload ["agent_flags" ] = {
896- k : v for k , v in vars (st .session_state .agent .flags ).items () if is_json_serializable (v )
897- }
898- payload ["agent_flags" ]["obs" ] = {
899- k : v
900- for k , v in vars (st .session_state .agent .flags .obs ).items ()
901- if is_json_serializable (v )
902- }
903- payload ["agent_flags" ]["action" ] = {
904- k : v
905- for k , v in vars (st .session_state .agent .flags .action ).items ()
906- if is_json_serializable (v )
907- }
908- payload ["goal" ] = st .session_state .last_obs ["goal" ]
909- payload ["steps" ] = []
928+ # save everything from the session in a way that is consistent
929+ # with how experiments are saved with AgentLab
930+
931+ # dir name has this format: 2025-07-14_16-46-47_tooluse-gpt-4-1-on-workarena-l1-task-name-sort
932+ exp_dir = (
933+ Path (save_dir )
934+ / f"{ datetime .now ().strftime ('%Y-%m-%d_%H-%M-%S' )} _genericagent_{ st .session_state .agent_args .agent_name } _on_{ st .session_state .benchmark } _{ st .session_state .env_args .task_name } _{ st .session_state .env_args .task_name } _{ st .session_state .env_args .task_seed } "
935+ )
936+ exp_dir .mkdir (parents = True , exist_ok = True )
937+
938+ # save package versions
939+ save_package_versions (exp_dir )
940+
941+ # create ExpArgs object
942+ exp_args = ExpArgs (
943+ agent_args = st .session_state .agent_args , env_args = st .session_state .env_args
944+ )
945+ with open (exp_dir / "exp_args.pkl" , "wb" ) as f :
946+ pickle .dump (exp_args , f )
947+
948+ # create StepInfo object for each step
910949 for i in range (len (st .session_state .action_history )):
911- step = {}
912- step ["action" ] = st .session_state .action_history [i ]
913- step ["thought" ] = st .session_state .thought_history [i ]
914- step ["prompt" ] = st .session_state .prompt_history [i ]
915- step ["screenshot" ] = get_base64_serialized_image (st .session_state .screenshot_history [i ])
916- step ["axtree" ] = st .session_state .axtree_history [i ]
917- payload ["steps" ].append (step )
918-
919- with open (os .path .join (save_dir , filename ), "w" ) as f :
920- json .dump (payload , f )
950+ step_info = StepInfo ()
951+ step_info .step = i
952+ step_info .obs = st .session_state .obs_history [i ]
953+ step_info .reward = st .session_state .reward_history [i ]
954+ step_info .terminated = st .session_state .terminated_history [i ]
955+ step_info .truncated = st .session_state .truncated_history [i ]
956+ step_info .action = st .session_state .action_history [i ]
957+ step_info .agent_info = st .session_state .action_info_history [i ]
958+ step_info .make_stats ()
959+ # TODO: set profiling stats
960+ step_info .task_info = st .session_state .env_info_history [i ].get ("task_info" , None )
961+ step_info .raw_reward = st .session_state .env_info_history [i ].get (
962+ "RAW_REWARD_GLOBAL" , None
963+ )
964+ step_info .save_step_info (exp_dir , save_screenshot = True , save_som = True )
965+
966+ st .success (f"Saved session state at { exp_dir } " )
921967
922968
923969def set_info_tabs ():
0 commit comments