44import  json 
55import  logging 
66import  os 
7+ import  pickle 
78from  collections  import  Counter 
89from  datetime  import  datetime 
910from  io  import  BytesIO 
11+ from  pathlib  import  Path 
1012
1113import  numpy  as  np 
1214import  PIL .Image 
1315import  requests 
1416import  streamlit  as  st 
1517from  agentlab .agents .generic_agent  import  __all__  as  ALL_AGENTS 
1618from  agentlab .experiments .exp_utils  import  RESULTS_DIR 
19+ from  agentlab .experiments .loop  import  ExpArgs , StepInfo , save_package_versions 
1720from  agentlab .llm .llm_utils  import  Discussion 
1821from  bgym  import  DEFAULT_BENCHMARKS 
1922from  dotenv  import  load_dotenv 
@@ -133,6 +136,12 @@ def reset_env_history():
133136    st .session_state .screenshot_history  =  []
134137    st .session_state .axtree_history  =  []
135138
139+     # related to env info 
140+     st .session_state .reward_history  =  []
141+     st .session_state .terminated_history  =  []
142+     st .session_state .truncated_history  =  []
143+     st .session_state .env_info_history  =  []
144+ 
136145
137146def  reset_agent_history ():
138147    logger .info ("Resetting agent history" )
@@ -150,13 +159,19 @@ def reset_agent_state():
150159    st .session_state .agent .reset ()
151160
152161
def step_env_history(obs, response_json):
    """Record one environment step in the Streamlit session state.

    A deep copy of ``obs`` becomes ``last_obs``; ``obs`` itself, its
    screenshot and its AXTree text are appended to their respective
    histories. The reward, terminated, truncated and info entries of
    ``response_json`` are then appended to the matching env-info histories.

    Args:
        obs: Observation for this step; indexed with
            ``Constants.SCREENSHOT`` and ``Constants.AXTREE_TXT``.
        response_json: Environment response for this step; indexed with
            ``"reward"``, ``"terminated"``, ``"truncated"`` and ``"info"``.
            A missing entry propagates as the lookup error.
    """
    logger.info("Stepping env history")
    state = st.session_state

    # Observation-derived histories.
    state.last_obs = copy.deepcopy(obs)
    state.obs_history.append(obs)
    state.screenshot_history.append(obs[Constants.SCREENSHOT])
    state.axtree_history.append(obs[Constants.AXTREE_TXT])

    # Other relevant info found in response_json, kept in lockstep with the
    # observation histories (one entry per step).
    env_info_fields = (
        ("reward_history", "reward"),
        ("terminated_history", "terminated"),
        ("truncated_history", "truncated"),
        ("env_info_history", "info"),
    )
    for history_name, response_key in env_info_fields:
        getattr(state, history_name).append(response_json[response_key])
174+ 
160175
161176def  step_agent_history (action , action_info ):
162177    logger .info ("Stepping agent history" )
@@ -185,6 +200,12 @@ def revert_env_history():
185200    st .session_state .screenshot_history .pop ()
186201    st .session_state .axtree_history .pop ()
187202
203+     # related to env info 
204+     st .session_state .reward_history .pop ()
205+     st .session_state .terminated_history .pop ()
206+     st .session_state .truncated_history .pop ()
207+     st .session_state .env_info_history .pop ()
208+ 
188209
189210def  revert_agent_history ():
190211    logger .info ("Reverting agent history" )
@@ -209,6 +230,12 @@ def restore_env_history(step: int):
209230    st .session_state .screenshot_history  =  copy .deepcopy (st .session_state .screenshot_history [:step ])
210231    st .session_state .axtree_history  =  copy .deepcopy (st .session_state .axtree_history [:step ])
211232
233+     # related to env info 
234+     st .session_state .reward_history  =  copy .deepcopy (st .session_state .reward_history [:step ])
235+     st .session_state .terminated_history  =  copy .deepcopy (st .session_state .terminated_history [:step ])
236+     st .session_state .truncated_history  =  copy .deepcopy (st .session_state .truncated_history [:step ])
237+     st .session_state .env_info_history  =  copy .deepcopy (st .session_state .env_info_history [:step ])
238+ 
212239
213240def  restore_agent_history (step : int ):
214241    logger .info (f"Restoring agent history to step { step }  " )
@@ -262,6 +289,8 @@ def set_session_state():
262289        st .session_state .task  =  None 
263290    if  "subtask"  not  in   st .session_state :
264291        st .session_state .subtask  =  None 
292+     if  "env_args"  not  in   st .session_state :
293+         st .session_state .env_args  =  None 
265294
266295    # current state 
267296    if  "agent"  not  in   st .session_state :
@@ -290,6 +319,14 @@ def set_session_state():
290319        st .session_state .action_info_history  =  None 
291320    if  "obs_history"  not  in   st .session_state :
292321        st .session_state .obs_history  =  None 
322+     if  "reward_history"  not  in   st .session_state :
323+         st .session_state .reward_history  =  None 
324+     if  "terminated_history"  not  in   st .session_state :
325+         st .session_state .terminated_history  =  None 
326+     if  "truncated_history"  not  in   st .session_state :
327+         st .session_state .truncated_history  =  None 
328+     if  "env_info_history"  not  in   st .session_state :
329+         st .session_state .env_info_history  =  None 
293330
294331    if  "has_clicked_prev"  not  in   st .session_state :
295332        st .session_state .has_clicked_prev  =  False 
@@ -362,6 +399,13 @@ def set_task_selector():
362399                    st .session_state .task  =  selected_task_str 
363400                    st .session_state .subtask  =  selected_subtask_str 
364401
402+                     st .session_state .env_args  =  [
403+                         elem 
404+                         for  elem  in  selected_benchmark .env_args_list 
405+                         if  elem .task_name  ==  selected_task_str 
406+                         and  str (elem .task_seed ) ==  str (selected_subtask_str )
407+                     ][0 ]
408+ 
365409                    reset_env_history ()
366410                    reset_agent_history ()
367411
@@ -423,11 +467,12 @@ def reset_environment():
423467        logger .error (resp .json ()[Constants .STATUS ])
424468        logger .error (resp .json ()[Constants .MESSAGE ])
425469    response_json  =  resp .json ()
470+     print (response_json .keys ())
426471    response_json  =  deserialize_response (response_json )
427472    obs  =  response_json [Constants .OBS ]
428473    if  st .session_state .agent .obs_preprocessor :
429474        obs  =  st .session_state .agent .obs_preprocessor (obs )
430-     step_env_history (obs )
475+     step_env_history (obs ,  response_json )
431476    st .session_state .action  =  None 
432477    st .session_state .action_info  =  None 
433478
@@ -447,7 +492,7 @@ def reload_task():
447492    obs  =  response_json [Constants .OBS ]
448493    if  st .session_state .agent .obs_preprocessor :
449494        obs  =  st .session_state .agent .obs_preprocessor (obs )
450-     step_env_history (obs )
495+     step_env_history (obs ,  response_json )
451496    st .session_state .action  =  None 
452497    st .session_state .action_info  =  None 
453498
@@ -468,7 +513,7 @@ def step_environment(action):
468513    obs  =  response_json [Constants .OBS ]
469514    if  st .session_state .agent .obs_preprocessor :
470515        obs  =  st .session_state .agent .obs_preprocessor (obs )
471-     step_env_history (obs )
516+     step_env_history (obs ,  response_json )
472517    st .session_state .action  =  None 
473518    st .session_state .action_info  =  None 
474519
@@ -880,44 +925,45 @@ def set_save_tab():
880925    save_dir  =  st .text_input ("Save Directory" , value = "~/Downloads" )
881926    save_dir  =  os .path .expanduser (save_dir )
882927    if  st .button ("Save Session State for Current Run" ):
883-         now_str  =  datetime .now ().strftime ("%Y_%m_%d__%H_%M_%S" )
884-         filename  =  f"agentlab_controller_state_{ now_str }  .json" 
885- 
886-         # prepare payload for saving 
887-         payload  =  {}
888-         payload ["timestamp" ] =  now_str 
889-         payload ["benchmark" ] =  st .session_state .benchmark 
890-         payload ["task" ] =  st .session_state .task 
891-         payload ["subtask" ] =  st .session_state .subtask 
892-         payload ["agent_args" ] =  {
893-             k : v  for  k , v  in  vars (st .session_state .agent_args ).items () if  is_json_serializable (v )
894-         }
895-         payload ["agent_flags" ] =  {
896-             k : v  for  k , v  in  vars (st .session_state .agent .flags ).items () if  is_json_serializable (v )
897-         }
898-         payload ["agent_flags" ]["obs" ] =  {
899-             k : v 
900-             for  k , v  in  vars (st .session_state .agent .flags .obs ).items ()
901-             if  is_json_serializable (v )
902-         }
903-         payload ["agent_flags" ]["action" ] =  {
904-             k : v 
905-             for  k , v  in  vars (st .session_state .agent .flags .action ).items ()
906-             if  is_json_serializable (v )
907-         }
908-         payload ["goal" ] =  st .session_state .last_obs ["goal" ]
909-         payload ["steps" ] =  []
928+         # save everything from the session in a way that is consistent 
929+         # with how experiments are saved with AgentLab 
930+ 
931+         # dir name has this format: 2025-07-14_16-46-47_tooluse-gpt-4-1-on-workarena-l1-task-name-sort 
932+         exp_dir  =  (
933+             Path (save_dir )
934+             /  f"{ datetime .now ().strftime ('%Y-%m-%d_%H-%M-%S' )}  _genericagent_{ st .session_state .agent_args .agent_name }  _on_{ st .session_state .benchmark }  _{ st .session_state .env_args .task_name }  _{ st .session_state .env_args .task_name }  _{ st .session_state .env_args .task_seed }  " 
935+         )
936+         exp_dir .mkdir (parents = True , exist_ok = True )
937+ 
938+         # save package versions 
939+         save_package_versions (exp_dir )
940+ 
941+         # create ExpArgs object 
942+         exp_args  =  ExpArgs (
943+             agent_args = st .session_state .agent_args , env_args = st .session_state .env_args 
944+         )
945+         with  open (exp_dir  /  "exp_args.pkl" , "wb" ) as  f :
946+             pickle .dump (exp_args , f )
947+ 
948+         # create StepInfo object for each step 
910949        for  i  in  range (len (st .session_state .action_history )):
911-             step  =  {}
912-             step ["action" ] =  st .session_state .action_history [i ]
913-             step ["thought" ] =  st .session_state .thought_history [i ]
914-             step ["prompt" ] =  st .session_state .prompt_history [i ]
915-             step ["screenshot" ] =  get_base64_serialized_image (st .session_state .screenshot_history [i ])
916-             step ["axtree" ] =  st .session_state .axtree_history [i ]
917-             payload ["steps" ].append (step )
918- 
919-         with  open (os .path .join (save_dir , filename ), "w" ) as  f :
920-             json .dump (payload , f )
950+             step_info  =  StepInfo ()
951+             step_info .step  =  i 
952+             step_info .obs  =  st .session_state .obs_history [i ]
953+             step_info .reward  =  st .session_state .reward_history [i ]
954+             step_info .terminated  =  st .session_state .terminated_history [i ]
955+             step_info .truncated  =  st .session_state .truncated_history [i ]
956+             step_info .action  =  st .session_state .action_history [i ]
957+             step_info .agent_info  =  st .session_state .action_info_history [i ]
958+             step_info .make_stats ()
959+             # TODO: set profiling stats 
960+             step_info .task_info  =  st .session_state .env_info_history [i ].get ("task_info" , None )
961+             step_info .raw_reward  =  st .session_state .env_info_history [i ].get (
962+                 "RAW_REWARD_GLOBAL" , None 
963+             )
964+             step_info .save_step_info (exp_dir , save_screenshot = True , save_som = True )
965+ 
966+         st .success (f"Saved session state at { exp_dir }  " )
921967
922968
923969def  set_info_tabs ():
0 commit comments